diff --git a/python/sdist/amici/_symbolic/de_model.py b/python/sdist/amici/_symbolic/de_model.py index 3ddf30fba1..fc027d20c8 100644 --- a/python/sdist/amici/_symbolic/de_model.py +++ b/python/sdist/amici/_symbolic/de_model.py @@ -1141,6 +1141,8 @@ def _generate_symbol(self, name: str) -> None: ] ) return + elif name == "deltax": + length = sp.Matrix(self.eq(name)).shape[0] else: length = len(self.eq(name)) self._syms[name] = sp.Matrix( @@ -2623,30 +2625,30 @@ def _process_hybridization(self, hybridization: dict) -> None: if added_expressions: self.toposort_expressions() - def get_explicit_roots(self) -> set[sp.Expr]: + def get_explicit_roots(self) -> list[sp.Expr]: """ Returns explicit formulas for all discontinuities (events) that can be precomputed :return: - set of symbolic roots + list of symbolic roots """ - return {root for e in self._events for root in e.get_trigger_times()} + return [root for e in self._events for root in e.get_trigger_times()] - def get_implicit_roots(self) -> set[sp.Expr]: + def get_implicit_roots(self) -> list[sp.Expr]: """ Returns implicit equations for all discontinuities (events) that have to be located via rootfinding :return: - set of symbolic roots + list of symbolic roots """ - return { + return [ e.get_val() for e in self._events if not e.has_explicit_trigger_times() - } - + ] + def has_algebraic_states(self) -> bool: """ Checks whether the model has algebraic states @@ -2664,6 +2666,24 @@ def has_event_assignments(self) -> bool: boolean indicating if event assignments are present """ return any(event.updates_state for event in self._events) + + def has_priority_events(self) -> bool: + """ + Checks whether the model has events with priorities defined + + :return: + boolean indicating if priority events are present + """ + return any(event.get_priority() is not None for event in self._events) + + def has_implicit_event_assignments(self) -> bool: + """ + Checks whether the model has event assignments with implicit triggers + + :return: + boolean indicating if event assignments with implicit triggers are present + """ + return any(event.updates_state and not event.has_explicit_trigger_times({}) for event in self._events) def toposort_expressions( self, reorder: bool = True diff --git a/python/sdist/amici/importers/petab/_petab_importer.py b/python/sdist/amici/importers/petab/_petab_importer.py index 89b4c6981d..fd1a4cb45e 100644 --- a/python/sdist/amici/importers/petab/_petab_importer.py +++ b/python/sdist/amici/importers/petab/_petab_importer.py @@ -336,7 +336,7 @@ def _do_import_sbml(self): show_model_info(self.petab_problem.model.sbml_model) sbml_importer = amici.SbmlImporter( - self.petab_problem.model.sbml_model, + self.petab_problem.model.sbml_model, jax=self._jax ) self._check_placeholders() diff --git a/python/sdist/amici/importers/petab/v1/sbml_import.py b/python/sdist/amici/importers/petab/v1/sbml_import.py index d49593c05d..f1953f9592 100644 --- a/python/sdist/amici/importers/petab/v1/sbml_import.py +++ b/python/sdist/amici/importers/petab/v1/sbml_import.py @@ -332,6 +332,7 @@ def import_model_sbml( sbml_importer = amici.SbmlImporter( sbml_model, discard_annotations=discard_sbml_annotations, + jax=jax, ) sbml_model = sbml_importer.sbml_model diff --git a/python/sdist/amici/importers/sbml/__init__.py b/python/sdist/amici/importers/sbml/__init__.py index 68e77bc017..771f9f0cb4 100644 --- a/python/sdist/amici/importers/sbml/__init__.py +++ b/python/sdist/amici/importers/sbml/__init__.py @@ -101,6 +101,7 @@ def __init__( show_sbml_warnings: bool = False, from_file: bool = True, discard_annotations: bool = False, + jax: bool = False, ) -> None: """ Initialize. @@ -188,6 +189,7 @@ def __init__( ignore_units=True, evaluate=True, ) + self.jax = jax @log_execution_time("loading SBML", logger) def _process_document(self) -> None: @@ -1893,6 +1895,21 @@ def _process_events(self) -> None: "priority": self._sympify(event.getPriority()), } + if self.jax: + # Add a negative event for JAX models to handle + # TODO: remove once condition function directions can be + # traced through diffrax solve + neg_event_id = event_id + "_negative" + neg_event_sym = sp.Symbol(neg_event_id) + self._symbols[SymbolId.EVENT][neg_event_sym] = { + "name": neg_event_id, + "value": -trigger, + "assignments": None, + "initial_value": not initial_value, + "use_values_from_trigger_time": use_trig_val, + "priority": self._sympify(event.getPriority()), + } + @log_execution_time("processing observation model", logger) def _process_observation_model( self, diff --git a/python/sdist/amici/jax/_simulation.py b/python/sdist/amici/jax/_simulation.py index 605791d33e..7b19e61517 100644 --- a/python/sdist/amici/jax/_simulation.py +++ b/python/sdist/amici/jax/_simulation.py @@ -33,6 +33,7 @@ def eq( term: diffrax.ODETerm, root_cond_fns: list[Callable], root_cond_fn: Callable, + delta_x: Callable, known_discs: jt.Float[jt.Array, "*nediscs"], max_steps: jnp.int_, ) -> tuple[jt.Float[jt.Array, "nxs"], jt.Float[jt.Array, "ne"], dict]: @@ -61,6 +62,8 @@ def eq( list of individual root condition functions for discontinuities :param root_cond_fn: root condition function for all discontinuities + :param delta_x: + function to compute state changes at events :param known_discs: known discontinuities, used to clip the step size controller :param max_steps: @@ -86,7 +89,6 @@ def eq( [None], diffrax.SaveAt(t1=True), term, - known_discs, dict(**STARTING_STATS), ) y1 = jnp.where( @@ -123,7 +125,6 @@ def body_fn(carry): [None] + [True] * len(root_cond_fns), diffrax.SaveAt(t1=True), term, - known_discs, stats, ) y0_next = jnp.where( @@ -136,19 +137,16 @@ def body_fn(carry): ) t0_next = jnp.where(jnp.isfinite(sol.ts), sol.ts, -jnp.inf).max() - y0_next, t0_next, h_next, stats = _handle_event( + y0_next, h_next, stats = _handle_event( t0_next, - jnp.inf, y0_next, p, tcl, h, - solver, - controller, root_finder, - diffrax.DirectAdjoint(), term, root_cond_fn, + delta_x, stats, ) @@ -186,6 +184,7 @@ def solve( term: diffrax.ODETerm, root_cond_fns: list[Callable], root_cond_fn: Callable, + delta_x: Callable, known_discs: jt.Float[jt.Array, "*nediscs"], ) -> tuple[jt.Float[jt.Array, "nt nxs"], jt.Float[jt.Array, "nt ne"], dict]: """ @@ -213,6 +212,8 @@ def solve( list of individual root condition functions for discontinuities :param root_cond_fn: root condition function for all discontinuities + :param delta_x: + function to compute state changes at events :param known_discs: known discontinuities, used to clip the step size controller :return: @@ -237,7 +238,6 @@ def solve( [], diffrax.SaveAt(ts=ts), term, - known_discs, dict(**STARTING_STATS), ) return sol.ys, jnp.repeat(h[None, :], sol.ys.shape[0]), stats @@ -254,7 +254,7 @@ def cond_fn(carry): def body_fn(carry): ys, t_start, y0, hs, h, stats = carry - sol, idx, stats = _run_segment( + sol, _, stats = _run_segment( t_start, ts[-1], y0, @@ -277,7 +277,6 @@ def body_fn(carry): ] ), term, - known_discs, stats, ) # update the solution for all timepoints in the simulated segment @@ -291,28 +290,22 @@ def body_fn(carry): y0_next = sol.ys[1][ -1 ] # next initial state is the last state of the current segment - ts_next = jnp.where( - ts > t0_next, ts, ts[-1] - ).min() # timepoint of next datapoint, don't step over that - y0_next, t0_next, h_next, stats = _handle_event( + y0_next, h_next, stats = _handle_event( t0_next, - ts_next, y0_next, p, tcl, h, - solver, - controller, root_finder, - adjoint, term, root_cond_fn, + delta_x, stats, ) - was_event = jnp.isin(ts, sol.ts[1]) - hs = jnp.where(was_event[:, None], h_next[None, :], hs) + after_event = sol.ts[1] < ts + hs = jnp.where(after_event[:, None], h_next[None, :], hs) return ys, t0_next, y0_next, hs, h_next, stats @@ -351,7 +344,6 @@ def _run_segment( cond_dirs: list[None | bool], saveat: diffrax.SaveAt, term: diffrax.ODETerm, - known_discs: jt.Float[jt.Array, "*nediscs"], stats: dict, ) -> tuple[diffrax.Solution, int, dict]: """Solve a single integration segment and return triggered event index, start time for the next segment, @@ -373,16 +365,6 @@ def _run_segment( else None ) - # manage events with explicit discontinuities - controller = ( - diffrax.ClipStepSizeController( - controller, - jump_ts=known_discs, - ) - if known_discs.size - else controller - ) - sol = diffrax.diffeqsolve( term, solver, @@ -429,17 +411,14 @@ def _run_segment( def _handle_event( t0_next: float, - t_max: float, y0_next: jt.Float[jt.Array, "nxs"], p: jt.Float[jt.Array, "np"], tcl: jt.Float[jt.Array, "ncl"], h: jt.Float[jt.Array, "ne"], - solver: diffrax.AbstractSolver, - controller: diffrax.AbstractStepSizeController, root_finder: AbstractRootFinder, - adjoint: diffrax.AbstractAdjoint, term: diffrax.ODETerm, root_cond_fn: Callable, + delta_x: Callable, stats: dict, ): args = (p, tcl, h) @@ -457,11 +436,15 @@ def _handle_event( ) roots_dir = jnp.sign(droot_dt) # direction of the root condition function - h_next = h + jnp.where( + y0_next, h_next = _apply_event_assignments( roots_found, roots_dir, - jnp.zeros_like(h), - ) # update heaviside variables based on the root condition function + y0_next, + p, + tcl, + h, + delta_x, + ) if os.getenv("JAX_DEBUG") == "1": jax.debug.print( @@ -472,4 +455,34 @@ def _handle_event( h, h_next, ) - return y0_next, t0_next, h_next, stats + + return y0_next, h_next, stats + +def _apply_event_assignments( + roots_found, + roots_dir, + y0_next, + p, + tcl, + h, + delta_x, +): + h_next = jnp.where( + roots_found, + jnp.logical_not(h), + h, + ) # update heaviside variables based on the root condition function + + mask = jnp.array( + [ + (roots_found & (roots_dir > 0.0) & (h == 0.0)) + for _ in range(y0_next.shape[0]) + ] + ).T + delx = delta_x(y0_next, p, tcl) + if y0_next.size: + delx = delx.reshape(delx.size // y0_next.shape[0], y0_next.shape[0],) + y0_up = jnp.where(mask, delx, 0.0) + y0_next = y0_next + jnp.sum(y0_up, axis=0) + + return y0_next, h_next diff --git a/python/sdist/amici/jax/jax.template.py b/python/sdist/amici/jax/jax.template.py index b5247d2eab..fe0ff12d8d 100644 --- a/python/sdist/amici/jax/jax.template.py +++ b/python/sdist/amici/jax/jax.template.py @@ -124,29 +124,28 @@ def _root_cond_fn(self, t, y, args, **_): TPL_W_SYMS = self._w(t, y, p, tcl, h) TPL_IROOT_EQ + TPL_EROOT_EQ - return TPL_IROOT_RET - - def _root_cond_fn_event(self, ie, t, y, args, **_): - """ - Root condition function for a specific event index. - """ - __, __, h = args - rval = self._root_cond_fn(t, y, args, **_) - # only allow root triggers where trigger function is negative (heaviside == 0) - masked_rval = jnp.where(h == 0.0, rval, 1.0) - return masked_rval.at[ie].get() - - def _root_cond_fns(self): - """Return root condition functions for discontinuities.""" - return [ - eqx.Partial(self._root_cond_fn_event, ie) - for ie in range(self.n_events) - ] + return jnp.hstack((TPL_IROOT_RET, TPL_EROOT_RET)) + + def _delta_x(self, y, p, tcl): + TPL_X_SYMS = y + TPL_P_SYMS = p + TPL_TCL_SYMS = tcl + # FIXME: workaround until state from event time is properly passed + TPL_X_OLD_SYMS = y + + TPL_DELTAX_EQ + + return TPL_DELTAX_RET + + @property + def event_initial_values(self): + return TPL_EVENT_INITIAL_VALUES @property def n_events(self): - return TPL_N_IEVENTS + return TPL_N_IEVENTS + TPL_N_EEVENTS @property def observable_ids(self): diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py index c8a3f56014..7a1636c651 100644 --- a/python/sdist/amici/jax/model.py +++ b/python/sdist/amici/jax/model.py @@ -15,7 +15,9 @@ import jaxtyping as jt from optimistix import AbstractRootFinder -from ._simulation import eq, solve +import os + +from ._simulation import eq, solve, _apply_event_assignments class ReturnValue(enum.Enum): @@ -267,22 +269,6 @@ def _known_discs( """ ... - @abstractmethod - def _root_cond_fns( - self, - ) -> list[Callable[[float, jt.Float[jt.Array, "nxs"], tuple], jt.Float]]: - """Return condition functions for implicit discontinuities. - - These functions are passed to :class:`diffrax.Event` and must evaluate - to zero when a discontinuity is triggered. - - :param p: - model parameters - :return: - tuple of callable root functions - """ - ... - @abstractmethod def _root_cond_fn( self, @@ -308,6 +294,20 @@ def _root_cond_fn( """ ... + @abstractmethod + def _delta_x( + self, y: jt.Float[jt.Array, "nxs"] + ) -> jt.Float[jt.Array, "nxs"]: + """ + Compute the state vector changes at discontinuities. + + :param y: + state vector + :return: + changes in the state vector at discontinuities + """ + ... + @property @abstractmethod def n_events(self) -> int: @@ -362,6 +362,50 @@ def expression_ids(self) -> list[str]: """ ... + def _root_cond_fn_event( + self, + ie: int, + t: float, + y: jt.Float[jt.Array, "nxs"], + args: tuple, + **_ + ): + """ + Root condition function for a specific event index. + + :param ie: + event index + :param t: + time point + :param y: + state vector + :param args: + tuple of arguments required for _root_cond_fn + :return: + mask of root condition value for the specified event index + """ + __, __, h = args + rval = self._root_cond_fn(t, y, args, **_) + # only allow root triggers where trigger function is negative (heaviside == 0) + masked_rval = jnp.where(h == 0.0, rval, 1.0) + return masked_rval.at[ie].get() + + def _root_cond_fns(self) -> list[Callable[[float, jt.Float[jt.Array, "nxs"], tuple], jt.Float]]: + """Return condition functions for implicit discontinuities. + + These functions are passed to :class:`diffrax.Event` and must evaluate + to zero when a discontinuity is triggered. + + :param p: + model parameters + :return: + iterable of callable root functions + """ + return [ + eqx.Partial(self._root_cond_fn_event, ie) + for ie in range(self.n_events) + ] + def _initialise_heaviside_variables( self, t0: jt.Float[jt.Scalar, ""], @@ -383,10 +427,17 @@ def _initialise_heaviside_variables( :return: heaviside variables """ - h0 = jnp.zeros((self.n_events,)) # dummy values + h0 = self.event_initial_values.astype(float) + if os.getenv("JAX_DEBUG") == "1": + jax.debug.print( + "h0: {}", + h0, + ) roots_found = self._root_cond_fn(t0, x_solver, (p, tcl, h0)) return jnp.where( - roots_found >= 0.0, jnp.ones_like(h0), jnp.zeros_like(h0) + jnp.logical_and(roots_found >= 0.0, h0 == 1.0), + jnp.ones_like(h0), + jnp.zeros_like(h0) ) def _x_rdatas( @@ -576,7 +627,17 @@ def simulate_condition_unjitted( x = jnp.where(mask_reinit, x_reinit, x) x_solver = self._x_solver(x) tcl = self._tcl(x, p) - h = self._initialise_heaviside_variables(t0, x_solver, p, tcl) + + x_solver, _, h, _ = self._handle_t0_event( + t0, + x_solver, + p, + tcl, + root_finder, + self._root_cond_fn, + self._delta_x, + {}, + ) # Dynamic simulation if ts_dyn.shape[0]: @@ -594,6 +655,7 @@ def simulate_condition_unjitted( diffrax.ODETerm(self._xdot), self._root_cond_fns(), self._root_cond_fn, + self._delta_x, self._known_discs(p), ) x_solver = x_dyn[-1, :] @@ -616,6 +678,7 @@ def simulate_condition_unjitted( diffrax.ODETerm(self._xdot), self._root_cond_fns(), self._root_cond_fn, + self._delta_x, self._known_discs(p), max_steps, ) @@ -836,10 +899,20 @@ def preequilibrate_condition( if x_reinit.shape[0]: x0 = jnp.where(mask_reinit, x_reinit, x0) tcl = self._tcl(x0, p) - h = self._initialise_heaviside_variables( - t0, self._x_solver(x0), p, tcl - ) + current_x = self._x_solver(x0) + + current_x, _, h, _ = self._handle_t0_event( + t0, + self._x_solver(x0), + p, + tcl, + root_finder, + self._root_cond_fn, + self._delta_x, + {}, + ) + current_x, _, stats_preeq = eq( p, tcl, @@ -852,12 +925,70 @@ def preequilibrate_condition( diffrax.ODETerm(self._xdot), self._root_cond_fns(), self._root_cond_fn, + self._delta_x, self._known_discs(p), max_steps, ) return self._x_rdata(current_x, tcl), dict(stats_preeq=stats_preeq) + def _handle_t0_event( + self, + t0_next: float, + y0_next: jt.Float[jt.Array, "nxs"], + p: jt.Float[jt.Array, "np"], + tcl: jt.Float[jt.Array, "ncl"], + root_finder: AbstractRootFinder, + root_cond_fn: Callable, + delta_x: Callable, + stats: dict, + ): + rf0 = self.event_initial_values - 0.5 + h = jnp.heaviside(rf0, 0.0) + args = (p, tcl, h) + rfx = root_cond_fn(t0_next, y0_next, args) + roots_dir = jnp.sign(rfx - rf0) + roots_found = jnp.sign(rfx) != jnp.sign(rf0) + + y0_next, h_next = _apply_event_assignments( + roots_found, + roots_dir, + y0_next, + p, + tcl, + h, + delta_x, + ) + + roots_zero = jnp.isclose( + rfx, 0.0, atol=root_finder.atol, rtol=root_finder.rtol + ) + droot_dt = ( + # ∂root_cond_fn/∂t + jax.jacfwd(root_cond_fn, argnums=0)(t0_next, y0_next, args) + + + # ∂root_cond_fn/∂y * ∂y/∂t + jax.jacfwd(root_cond_fn, argnums=1)(t0_next, y0_next, args) + @ self._xdot(t0_next, y0_next, args) + ) + h_next = jnp.where( + roots_zero, + droot_dt >= 0.0, + h_next, + ) + + if os.getenv("JAX_DEBUG") == "1": + jax.debug.print( + "h: {}, rf0: {}, rfx: {}, roots_found: {}, roots_dir: {}, h_next: {}", + h, + rf0, + rfx, + roots_found, + roots_dir, + h_next, + ) + + return y0_next, t0_next, h_next, stats def safe_log(x: jnp.float_) -> jnp.float_: """ diff --git a/python/sdist/amici/jax/ode_export.py b/python/sdist/amici/jax/ode_export.py index f6225c7df4..29b6458957 100644 --- a/python/sdist/amici/jax/ode_export.py +++ b/python/sdist/amici/jax/ode_export.py @@ -60,7 +60,8 @@ def _jax_variable_equations( f"{eq_name.upper()}_EQ": "\n".join( code_printer._get_sym_lines( (s.name for s in model.sym(eq_name)), - model.eq(eq_name).subs(subs), + # sp.Matrix to support event assignments which are lists + sp.Matrix(model.eq(eq_name)).subs(subs), indent, ) )[indent:] # remove indent for first line @@ -76,7 +77,7 @@ def _jax_return_variables( f"{eq_name.upper()}_RET": _jnp_array_str( s.name for s in model.sym(eq_name) ) - if model.sym(eq_name) + if model.sym(eq_name) and sp.Matrix(model.eq(eq_name)).shape[0] else "jnp.array([])" for eq_name in eq_names } @@ -144,14 +145,19 @@ def __init__( """ set_log_level(logger, verbose) - if ode_model.has_event_assignments(): + if ode_model.has_algebraic_states(): raise NotImplementedError( - "The JAX backend does not support models with event assignments." + "The JAX backend does not support models with algebraic states." ) - if ode_model.has_algebraic_states(): + if ode_model.has_priority_events(): raise NotImplementedError( - "The JAX backend does not support models with algebraic states." + "The JAX backend does not support event priorities." + ) + + if ode_model.has_implicit_event_assignments(): + raise NotImplementedError( + "The JAX backend does not support event assignments with implicit triggers." ) self.verbose: bool = logger.getEffectiveLevel() <= logging.DEBUG @@ -200,7 +206,9 @@ def _generate_jax_code(self) -> None: "x_solver", "x_rdata", "total_cl", + "eroot", "iroot", + "deltax", ) sym_names = ( "p", @@ -215,6 +223,7 @@ def _generate_jax_code(self) -> None: "sigmay", "x_rdata", "iroot", + "x_old", ) indent = 8 @@ -252,12 +261,18 @@ def _generate_jax_code(self) -> None: "P_VALUES": _jnp_array_str(self.model.val("p")), "ROOTS": _jnp_array_str( { - root + _print_trigger_root(root) for e in self.model._events for root in e.get_trigger_times() } ), "N_IEVENTS": str(len(self.model.get_implicit_roots())), + "N_EEVENTS": str(len(self.model.get_explicit_roots())), + "EVENT_INITIAL_VALUES": _jnp_array_str( + [ + e.get_initial_value() for e in self.model._events + ] + ), **{ "MODEL_NAME": self.model_name, # keep track of the API version that the model was generated with so we @@ -333,3 +348,13 @@ def set_name(self, model_name: str) -> None: ) self.model_name = model_name + +def _print_trigger_root(root: sp.Expr) -> str: + """Convert a trigger root expression into a string representation. + + :param root: The trigger root expression. + :return: A string representation of the trigger root. + """ + if root.is_number: + return float(root) + return str(root).replace(" ", "") diff --git a/python/tests/test_jax.py b/python/tests/test_jax.py index 2d6e111b2c..838a9f8144 100644 --- a/python/tests/test_jax.py +++ b/python/tests/test_jax.py @@ -343,39 +343,47 @@ def test_time_dependent_discontinuity(tmp_path): sbml = antimony2sbml(ant_model) importer = SbmlImporter(sbml, from_file=False) - importer.sbml2jax("time_disc", output_dir=tmp_path) - - module = amici._module_from_path("time_disc", tmp_path / "__init__.py") - model = module.Model() - - p = jnp.array([1.0]) - x0_full = model._x0(0.0, p) - tcl = model._tcl(x0_full, p) - x0 = model._x_solver(x0_full) - ts = jnp.array([0.0, 1.0, 2.0]) - h = model._initialise_heaviside_variables(0.0, model._x_solver(x0), p, tcl) - - assert len(model._root_cond_fns()) > 0 - assert model._known_discs(p).size == 0 - - ys, _, _ = solve( - p, - ts, - tcl, - h, - x0, - diffrax.Tsit5(), - diffrax.PIDController(**DEFAULT_CONTROLLER_SETTINGS), - optimistix.Newton(atol=1e-8, rtol=1e-8), - 1000, - diffrax.DirectAdjoint(), - diffrax.ODETerm(model._xdot), - model._root_cond_fns(), - model._root_cond_fn, - model._known_discs(p), - ) - assert ys.shape[0] == ts.shape[0] + try: + importer.sbml2jax("time_disc", output_dir=tmp_path) + + module = amici._module_from_path("time_disc", tmp_path / "__init__.py") + model = module.Model() + + p = jnp.array([1.0]) + x0_full = model._x0(0.0, p) + tcl = model._tcl(x0_full, p) + x0 = model._x_solver(x0_full) + ts = jnp.array([0.0, 1.0, 2.0]) + h = model._initialise_heaviside_variables(0.0, model._x_solver(x0), p, tcl) + + assert len(model._root_cond_fns()) > 0 + assert model._known_discs(p).size == 0 + + ys, _, _ = solve( + p, + ts, + tcl, + h, + x0, + diffrax.Tsit5(), + diffrax.PIDController(**DEFAULT_CONTROLLER_SETTINGS), + optimistix.Newton(atol=1e-8, rtol=1e-8), + 1000, + diffrax.DirectAdjoint(), + diffrax.ODETerm(model._xdot), + model._root_cond_fns(), + model._root_cond_fn, + model._delta_x, + model._known_discs(p), + ) + + assert ys.shape[0] == ts.shape[0] + + except NotImplementedError as err: + if "The JAX backend does not support" in str(err): + pytest.skip(str(err)) + raise err @skip_on_valgrind @@ -396,34 +404,41 @@ def test_time_dependent_discontinuity_equilibration(tmp_path): sbml = antimony2sbml(ant_model) importer = SbmlImporter(sbml, from_file=False) - importer.sbml2jax("time_disc_eq", output_dir=tmp_path) - - module = amici._module_from_path("time_disc_eq", tmp_path / "__init__.py") - model = module.Model() - - p = jnp.array([1.0]) - x0_full = model._x0(0.0, p) - tcl = model._tcl(x0_full, p) - x0 = model._x_solver(x0_full) - h = model._initialise_heaviside_variables(0.0, model._x_solver(x0), p, tcl) - - assert len(model._root_cond_fns()) > 0 - assert model._known_discs(p).size == 0 - - xs, _, _ = eq( - p, - tcl, - h, - x0, - diffrax.Tsit5(), - diffrax.PIDController(**DEFAULT_CONTROLLER_SETTINGS), - optimistix.Newton(atol=1e-8, rtol=1e-8), - diffrax.steady_state_event(rtol=1e-8, atol=1e-8), - diffrax.ODETerm(model._xdot), - model._root_cond_fns(), - model._root_cond_fn, - model._known_discs(p), - 1000, - ) + try: + importer.sbml2jax("time_disc_eq", output_dir=tmp_path) + + module = amici._module_from_path("time_disc_eq", tmp_path / "__init__.py") + model = module.Model() + + p = jnp.array([1.0]) + x0_full = model._x0(0.0, p) + tcl = model._tcl(x0_full, p) + x0 = model._x_solver(x0_full) + h = model._initialise_heaviside_variables(0.0, model._x_solver(x0), p, tcl) + + assert len(model._root_cond_fns()) > 0 + assert model._known_discs(p).size == 0 + + xs, _, _ = eq( + p, + tcl, + h, + x0, + diffrax.Tsit5(), + diffrax.PIDController(**DEFAULT_CONTROLLER_SETTINGS), + optimistix.Newton(atol=1e-8, rtol=1e-8), + diffrax.steady_state_event(rtol=1e-8, atol=1e-8), + diffrax.ODETerm(model._xdot), + model._root_cond_fns(), + model._root_cond_fn, + model._delta_x, + model._known_discs(p), + 1000, + ) + + assert_allclose(xs[0], 0.0, atol=1e-2) - assert_allclose(xs[0], 0.0, atol=1e-2) + except NotImplementedError as err: + if "The JAX backend does not support" in str(err): + pytest.skip(str(err)) + raise err diff --git a/tests/benchmark_models/test_petab_benchmark_jax.py b/tests/benchmark_models/test_petab_benchmark_jax.py index 2511e65122..52d0d053ab 100644 --- a/tests/benchmark_models/test_petab_benchmark_jax.py +++ b/tests/benchmark_models/test_petab_benchmark_jax.py @@ -23,6 +23,8 @@ settings, ) +import diffrax + jax.config.update("jax_enable_x64", True) @@ -35,9 +37,11 @@ def test_jax_llh(benchmark_problem): problem_id, flat_petab_problem, petab_problem, amici_model = ( benchmark_problem ) - if problem_id == "Smith_BMCSystBiol2013": + + to_skip = ["Smith_BMCSystBiol2013", "Oliveira_NatCommun2021", "SalazarCavazos_MBoC2020"] + if problem_id in to_skip: pytest.skip( - "Skipping Smith_BMCSystBiol2013 due to non-supported events in JAX." + f"Skipping {problem_id} due to non-supported events in JAX." ) amici_solver = amici_model.create_solver() @@ -102,10 +106,25 @@ def test_jax_llh(benchmark_problem): ) if problem_id in problems_for_gradient_check: + if problem_id == "Weber_BMC2015": + atol = cur_settings.atol_sim + rtol = cur_settings.rtol_sim + max_steps = 2 * 10**5 + else: + atol = 1e-8 + rtol = 1e-8 + max_steps = 1024 beartype(run_simulations)(jax_problem) (llh_jax, _), sllh_jax = eqx.filter_value_and_grad( run_simulations, has_aux=True - )(jax_problem) + )( + jax_problem, + max_steps=max_steps, + controller=diffrax.PIDController( + atol=atol, + rtol=rtol, + ) + ) else: llh_jax, _ = beartype(run_simulations)(jax_problem) diff --git a/tests/sbml/testSBMLSuiteJax.py b/tests/sbml/testSBMLSuiteJax.py index 2ddb36820e..772c881e49 100644 --- a/tests/sbml/testSBMLSuiteJax.py +++ b/tests/sbml/testSBMLSuiteJax.py @@ -51,7 +51,7 @@ def get_expression_ids(self): def compile_model_jax(sbml_dir: Path, test_id: str, model_dir: Path): model_dir.mkdir(parents=True, exist_ok=True) sbml_file = find_model_file(sbml_dir, test_id) - sbml_importer = amici.SbmlImporter(sbml_file) + sbml_importer = amici.SbmlImporter(sbml_file, jax=True) model_name = f"SBMLTest{test_id}_jax" sbml_importer.sbml2jax(model_name, output_dir=model_dir) model_module = amici.import_model_module(model_dir.name, model_dir.parent) @@ -159,6 +159,9 @@ def test_sbml_testsuite_case_jax( 276, 277, 279, + 356, + 357, + 752, 1148, 1159, 1160,