Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions python/sdist/amici/_symbolic/de_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1595,6 +1595,11 @@ def _compute_equation(self, name: str) -> None:

self._eqs[name] = event_eqs

elif name == "x_old":
self._eqs[name] = sp.Matrix(
[state.get_x_rdata() for state in self.states()]
)

elif name == "z":
event_observables = [
sp.zeros(self.num_eventobs(), 1) for _ in self._events
Expand Down
15 changes: 11 additions & 4 deletions python/sdist/amici/importers/petab/_petab_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from amici._symbolic import DEModel, Event
from amici.importers.utils import MeasurementChannel, amici_time_symbol
from amici.logging import get_logger
from amici.jax.petab import JAXProblem

from .v1.sbml_import import _add_global_parameter

Expand Down Expand Up @@ -151,10 +152,6 @@ def __init__(
"PEtab v2 importer currently only supports SBML and PySB "
f"models. Got {self.petab_problem.model.type_id!r}."
)
if jax:
raise NotImplementedError(
"PEtab v2 importer currently does not support JAX. "
)

if self._debug:
print("PetabImpoter.__init__: petab_problem:")
Expand Down Expand Up @@ -577,6 +574,11 @@ def import_module(self, force_import: bool = False) -> amici.ModelModule:
else:
self._do_import_pysb()

if self._jax:
return amici.import_model_module(
Path(self.output_dir).stem, Path(self.output_dir).parent
)

return amici.import_model_module(
self._module_name,
self.output_dir,
Expand All @@ -601,6 +603,11 @@ def create_simulator(
"""
from amici.sim.sundials.petab import ExperimentManager, PetabSimulator

if self._jax:
model_module = self.import_module(force_import=force_import)
model = model_module.Model()
return JAXProblem(model, self.petab_problem)

model = self.import_module(force_import=force_import).get_model()
em = ExperimentManager(model=model, petab_problem=self.petab_problem)
return PetabSimulator(em=em)
Expand Down
10 changes: 6 additions & 4 deletions python/sdist/amici/importers/petab/v1/parameter_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def create_parameter_mapping(
converter_config = (
libsbml.SBMLLocalParameterConverter().getDefaultProperties()
)
petab_problem.sbml_document.convert(converter_config)
petab_problem.model.sbml_document.convert(converter_config)
else:
logger.debug(
"No petab_problem.sbml_document is set. Cannot "
Expand Down Expand Up @@ -474,9 +474,11 @@ def create_parameter_mapping_for_condition(
# ExpData.x0, but in the case of pre-equilibration this would not allow for
# resetting initial states.

if states_in_condition_table := get_states_in_condition_table(
states_in_condition_table = get_states_in_condition_table(
petab_problem, condition
):
)

if states_in_condition_table:
# set indicator fixed parameter for preeq
# (we expect here, that this parameter was added during import and
# that it was not added by the user with a different meaning...)
Expand Down Expand Up @@ -525,7 +527,7 @@ def create_parameter_mapping_for_condition(
value,
fill_fixed_parameters=fill_fixed_parameters,
)
# set dummy value as above
# set dummy value as above
if condition_map_preeq:
condition_map_preeq[init_par_id] = 0.0
condition_scale_map_preeq[init_par_id] = LIN
Expand Down
7 changes: 6 additions & 1 deletion python/sdist/amici/importers/petab/v1/sbml_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import re
from _collections import OrderedDict
from itertools import chain
import pandas as pd
from pathlib import Path

import libsbml
import petab.v1 as petab
import petab.v2 as petabv2
import sympy as sp
from petab.v1.models import MODEL_TYPE_SBML
from sympy.abc import _clash
Expand Down Expand Up @@ -304,7 +306,10 @@ def import_model_sbml(

if validate:
logger.info("Validating PEtab problem ...")
petab.lint_problem(petab_problem)
if isinstance(petab_problem, petabv2.Problem):
petabv2.lint_problem(petab_problem)
else:
petab.lint_problem(petab_problem)

# Model name from SBML ID or filename
if model_name is None:
Expand Down
14 changes: 12 additions & 2 deletions python/sdist/amici/jax/_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def eq(
tcl: jt.Float[jt.Array, "ncl"],
h0: jt.Float[jt.Array, "ne"],
x0: jt.Float[jt.Array, "nxs"],
h_mask: jt.Bool[jt.Array, "ne"],
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
root_finder: AbstractRootFinder,
Expand Down Expand Up @@ -147,6 +148,7 @@ def body_fn(carry):
term,
root_cond_fn,
delta_x,
h_mask,
stats,
)

Expand All @@ -172,10 +174,12 @@ def body_fn(carry):

def solve(
p: jt.Float[jt.Array, "np"],
t0: jnp.float_,
ts: jt.Float[jt.Array, "nt_dyn"],
tcl: jt.Float[jt.Array, "ncl"],
h: jt.Float[jt.Array, "ne"],
x0: jt.Float[jt.Array, "nxs"],
h_mask: jt.Bool[jt.Array, "ne"],
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
root_finder: AbstractRootFinder,
Expand All @@ -192,6 +196,8 @@ def solve(

:param p:
parameters
:param t0:
initial time point
:param ts:
time points at which solutions are evaluated
:param tcl:
Expand Down Expand Up @@ -223,7 +229,7 @@ def solve(
if not root_cond_fns:
# no events, we can just run a single segment
sol, _, stats = _run_segment(
0.0,
t0,
ts[-1],
x0,
p,
Expand Down Expand Up @@ -301,6 +307,7 @@ def body_fn(carry):
term,
root_cond_fn,
delta_x,
h_mask,
stats,
)

Expand All @@ -315,7 +322,7 @@ def body_fn(carry):
body_fn,
(
jnp.zeros((ts.shape[0], x0.shape[0]), dtype=x0.dtype) + x0,
0.0,
t0,
x0,
jnp.zeros((ts.shape[0], h.shape[0]), dtype=h.dtype),
h,
Expand Down Expand Up @@ -419,6 +426,7 @@ def _handle_event(
term: diffrax.ODETerm,
root_cond_fn: Callable,
delta_x: Callable,
h_mask: jt.Bool[jt.Array, "ne"],
stats: dict,
):
args = (p, tcl, h)
Expand Down Expand Up @@ -446,6 +454,8 @@ def _handle_event(
delta_x,
)

h_next = jnp.where(h_mask, h_next, h)

if os.getenv("JAX_DEBUG") == "1":
jax.debug.print(
"rootvals: {}, roots_found: {}, roots_dir: {}, h: {}, h_next: {}",
Expand Down
36 changes: 33 additions & 3 deletions python/sdist/amici/jax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,17 +598,28 @@ def simulate_condition_unjitted(
init_override: jt.Float[jt.Array, "*nx"] = jnp.array([]),
init_override_mask: jt.Bool[jt.Array, "*nx"] = jnp.array([]),
ts_mask: jt.Bool[jt.Array, "nt"] = jnp.array([]),
h_mask: jt.Bool[jt.Array, "ne"] = jnp.array([]),
t_zero: jnp.float_ = 0.0,
ret: ReturnValue = ReturnValue.llh,
) -> tuple[jt.Float[jt.Array, "*nt"], dict]:
"""
Unjitted version of simulate_condition.

See :meth:`simulate_condition` for full documentation.
"""
t0 = 0.0
t0 = t_zero
if p is None:
p = self.parameters

if os.getenv("JAX_DEBUG") == "1":
jax.debug.print(
"x_reinit: {}, x_preeq: {}, x_def: {}. p: {}",
x_reinit,
x_preeq,
self._x0(t0, p),
p,
)

if x_preeq.shape[0]:
x = x_preeq
elif init_override.shape[0]:
Expand All @@ -625,6 +636,7 @@ def simulate_condition_unjitted(
# Re-initialization
if x_reinit.shape[0]:
x = jnp.where(mask_reinit, x_reinit, x)

x_solver = self._x_solver(x)
tcl = self._tcl(x, p)

Expand All @@ -636,17 +648,20 @@ def simulate_condition_unjitted(
root_finder,
self._root_cond_fn,
self._delta_x,
h_mask,
{},
)

# Dynamic simulation
if ts_dyn.shape[0]:
x_dyn, h_dyn, stats_dyn = solve(
p,
t0,
ts_dyn,
tcl,
h,
x_solver,
h_mask,
solver,
controller,
root_finder,
Expand All @@ -671,6 +686,7 @@ def simulate_condition_unjitted(
tcl,
h,
x_solver,
h_mask,
solver,
controller,
root_finder,
Expand Down Expand Up @@ -776,6 +792,8 @@ def simulate_condition(
init_override: jt.Float[jt.Array, "*nx"] = jnp.array([]),
init_override_mask: jt.Bool[jt.Array, "*nx"] = jnp.array([]),
ts_mask: jt.Bool[jt.Array, "nt"] = jnp.array([]),
h_mask: jt.Bool[jt.Array, "ne"] = jnp.array([]),
t_zero: jnp.float_ = 0.0,
ret: ReturnValue = ReturnValue.llh,
) -> tuple[jt.Float[jt.Array, "*nt"], dict]:
r"""
Expand Down Expand Up @@ -828,6 +846,9 @@ def simulate_condition(
:param ts_mask:
mask to remove (padded) time points. If `True`, the corresponding time point is used for the evaluation of
the output. Only applied if ret is ReturnValue.llh, ReturnValue.nllhs, ReturnValue.res, or ReturnValue.chi2.
:param h_mask:
mask for heaviside variables. If `True`, the corresponding heaviside variable is updated during simulation, otherwise it
it marked as 1.0.
:param ret:
which output to return. See :class:`ReturnValue` for available options.
:return:
Expand All @@ -854,6 +875,8 @@ def simulate_condition(
init_override,
init_override_mask,
ts_mask,
h_mask,
t_zero,
ret,
)

Expand All @@ -863,6 +886,7 @@ def preequilibrate_condition(
p: jt.Float[jt.Array, "np"] | None,
x_reinit: jt.Float[jt.Array, "*nx"],
mask_reinit: jt.Bool[jt.Array, "*nx"],
h_mask: jt.Bool[jt.Array, "ne"],
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
root_finder: AbstractRootFinder,
Expand Down Expand Up @@ -910,6 +934,7 @@ def preequilibrate_condition(
root_finder,
self._root_cond_fn,
self._delta_x,
h_mask,
{},
)

Expand All @@ -918,6 +943,7 @@ def preequilibrate_condition(
tcl,
h,
current_x,
h_mask,
solver,
controller,
root_finder,
Expand All @@ -941,10 +967,12 @@ def _handle_t0_event(
root_finder: AbstractRootFinder,
root_cond_fn: Callable,
delta_x: Callable,
h_mask: jt.Bool[jt.Array, "ne"],
stats: dict,
):
y0 = y0_next.copy()
rf0 = self.event_initial_values - 0.5
h = jnp.heaviside(rf0, 0.0)
h = jnp.where(h_mask, jnp.heaviside(rf0, 0.0), jnp.ones_like(rf0))
args = (p, tcl, h)
rfx = root_cond_fn(t0_next, y0_next, args)
roots_dir = jnp.sign(rfx - rf0)
Expand Down Expand Up @@ -979,13 +1007,15 @@ def _handle_t0_event(

if os.getenv("JAX_DEBUG") == "1":
jax.debug.print(
"h: {}, rf0: {}, rfx: {}, roots_found: {}, roots_dir: {}, h_next: {}",
"handle_t0_event h: {}, rf0: {}, rfx: {}, roots_found: {}, roots_dir: {}, h_next: {}, y0_next: {}, y0: {}",
h,
rf0,
rfx,
roots_found,
roots_dir,
h_next,
y0_next,
y0,
)

return y0_next, t0_next, h_next, stats
Expand Down
13 changes: 9 additions & 4 deletions python/sdist/amici/jax/ode_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,16 +149,21 @@ def __init__(
raise NotImplementedError(
"The JAX backend does not support models with algebraic states."
)

# if not ode_model.has_only_time_dependent_event_assignments():
# raise NotImplementedError(
# "The JAX backend does not support event assignments with explicit non-time dependent triggers."
# )

if ode_model.has_priority_events():
raise NotImplementedError(
"The JAX backend does not support event priorities."
)

if ode_model.has_implicit_event_assignments():
raise NotImplementedError(
"The JAX backend does not support event assignments with implicit triggers."
)
# if ode_model.has_implicit_event_assignments():
# raise NotImplementedError(
# "The JAX backend does not support event assignments with implicit triggers."
# )

self.verbose: bool = logger.getEffectiveLevel() <= logging.DEBUG

Expand Down
Loading
Loading