diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py
index c2134bf9f6..c8a3f56014 100644
--- a/python/sdist/amici/jax/model.py
+++ b/python/sdist/amici/jax/model.py
@@ -523,8 +523,7 @@ def _sigmays(
             in_axes=(0, 0, None, None, 0, 0, 0, 0),
         )(ts, xs, p, tcl, hs, iys, ops, nps)
 
-    @eqx.filter_jit
-    def simulate_condition(
+    def simulate_condition_unjitted(
         self,
         p: jt.Float[jt.Array, "np"] | None,
         ts_dyn: jt.Float[jt.Array, "nt_dyn"],
@@ -549,58 +548,11 @@ def simulate_condition(
         init_override_mask: jt.Bool[jt.Array, "*nx"] = jnp.array([]),
         ts_mask: jt.Bool[jt.Array, "nt"] = jnp.array([]),
         ret: ReturnValue = ReturnValue.llh,
-    ) -> tuple[jt.Float[jt.Array, "nt *nx"] | jnp.float_, dict]:
-        r"""
-        Simulate a condition.
-
+    ) -> tuple[jt.Float[jt.Array, "*nt"], dict]:
+        """
+        Unjitted version of simulate_condition.
+
-        :param p:
-            parameters for simulation ordered according to ids in :ivar parameter_ids:. If ``None``,
-            the values stored in :attr:`parameters` are used.
-        :param ts_dyn:
-            time points for dynamic simulation. Sorted in monotonically increasing order but duplicate time points are
-            allowed to facilitate the evaluation of multiple observables at specific time points.
-        :param ts_posteq:
-            time points for post-equilibration. Usually valued \Infty, but needs to be shaped according to
-            the number of observables that are evaluated after post-equilibration.
-        :param my:
-            observed data
-        :param iys:
-            indices of the observables according to ordering in :ivar observable_ids:
-        :param iy_trafos:
-            indices of transformations for observables
-        :param ops:
-            observable parameters
-        :param nps:
-            noise parameters
-        :param solver:
-            ODE solver
-        :param controller:
-            step size controller
-        :param adjoint:
-            adjoint method. Recommended values are `diffrax.DirectAdjoint()` for jax.jacfwd (with vector-valued
-            outputs) and `diffrax.RecursiveCheckpointAdjoint()` for jax.grad (for scalar-valued outputs).
-        :param steady_state_event:
-            event function for steady state. See :func:`diffrax.steady_state_event` for details.
-        :param max_steps:
-            maximum number of solver steps
-        :param x_preeq:
-            initial state vector for pre-equilibration. If not provided, the initial state vector is computed using
-            :meth:`_x0`.
-        :param mask_reinit:
-            mask for re-initialization. If `True`, the corresponding state variable is re-initialized.
-        :param x_reinit:
-            re-initialized state vector. If not provided, the state vector is not re-initialized.
-        :param init_override:
-            override model input e.g. with neural net outputs. If not provided, the inputs are not overridden.
-        :param init_override_mask:
-            mask for input override. If `True`, the corresponding input is replaced with the corresponding value from `init_override`.
-        :param ts_mask:
-            mask to remove (padded) time points. If `True`, the corresponding time point is used for the evaluation of
-            the output. Only applied if ret is ReturnValue.llh, ReturnValue.nllhs, ReturnValue.res, or ReturnValue.chi2.
-        :param ret:
-            which output to return. See :class:`ReturnValue` for available options.
-        :return:
-            output according to `ret` and general results/statistics
+        See :meth:`simulate_condition` for full documentation.
""" t0 = 0.0 if p is None: @@ -736,6 +688,112 @@ def simulate_condition( return output, stats + @eqx.filter_jit + def simulate_condition( + self, + p: jt.Float[jt.Array, "np"] | None, + ts_dyn: jt.Float[jt.Array, "nt_dyn"], + ts_posteq: jt.Float[jt.Array, "nt_posteq"], + my: jt.Float[jt.Array, "nt"], + iys: jt.Int[jt.Array, "nt"], + iy_trafos: jt.Int[jt.Array, "nt"], + ops: jt.Float[jt.Array, "nt *nop"], + nps: jt.Float[jt.Array, "nt *nnp"], + solver: diffrax.AbstractSolver, + controller: diffrax.AbstractStepSizeController, + root_finder: AbstractRootFinder, + adjoint: diffrax.AbstractAdjoint, + steady_state_event: Callable[ + ..., diffrax._custom_types.BoolScalarLike + ], + max_steps: int | jnp.int_, + x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]), + mask_reinit: jt.Bool[jt.Array, "*nx"] = jnp.array([]), + x_reinit: jt.Float[jt.Array, "*nx"] = jnp.array([]), + init_override: jt.Float[jt.Array, "*nx"] = jnp.array([]), + init_override_mask: jt.Bool[jt.Array, "*nx"] = jnp.array([]), + ts_mask: jt.Bool[jt.Array, "nt"] = jnp.array([]), + ret: ReturnValue = ReturnValue.llh, + ) -> tuple[jt.Float[jt.Array, "*nt"], dict]: + r""" + Simulate a condition (JIT-compiled version). + + This is the JIT-compiled version for optimal performance. For runtime type checking + with beartype, use :meth:`simulate_condition_unjitted` instead. + + :param p: + parameters for simulation ordered according to ids in :ivar parameter_ids:. If ``None``, + the values stored in :attr:`parameters` are used. + :param ts_dyn: + time points for dynamic simulation. Sorted in monotonically increasing order but duplicate time points are + allowed to facilitate the evaluation of multiple observables at specific time points. + :param ts_posteq: + time points for post-equilibration. Usually valued \Infty, but needs to be shaped according to + the number of observables that are evaluated after post-equilibration. + :param my: + observed data + :param iys: + indices of the observables according to ordering in :ivar observable_ids: + :param iy_trafos: + indices of transformations for observables + :param ops: + observable parameters + :param nps: + noise parameters + :param solver: + ODE solver + :param controller: + step size controller + :param adjoint: + adjoint method. Recommended values are `diffrax.DirectAdjoint()` for jax.jacfwd (with vector-valued + outputs) and `diffrax.RecursiveCheckpointAdjoint()` for jax.grad (for scalar-valued outputs). + :param steady_state_event: + event function for steady state. See :func:`diffrax.steady_state_event` for details. + :param max_steps: + maximum number of solver steps + :param x_preeq: + initial state vector for pre-equilibration. If not provided, the initial state vector is computed using + :meth:`_x0`. + :param mask_reinit: + mask for re-initialization. If `True`, the corresponding state variable is re-initialized. + :param x_reinit: + re-initialized state vector. If not provided, the state vector is not re-initialized. + :param init_override: + override model input e.g. with neural net outputs. If not provided, the inputs are not overridden. + :param init_override_mask: + mask for input override. If `True`, the corresponding input is replaced with the corresponding value from `init_override`. + :param ts_mask: + mask to remove (padded) time points. If `True`, the corresponding time point is used for the evaluation of + the output. Only applied if ret is ReturnValue.llh, ReturnValue.nllhs, ReturnValue.res, or ReturnValue.chi2. + :param ret: + which output to return. 
+            which output to return. See :class:`ReturnValue` for available options.
+        :return:
+            output according to `ret` and general results/statistics
+        """
+        return self.simulate_condition_unjitted(
+            p,
+            ts_dyn,
+            ts_posteq,
+            my,
+            iys,
+            iy_trafos,
+            ops,
+            nps,
+            solver,
+            controller,
+            root_finder,
+            adjoint,
+            steady_state_event,
+            max_steps,
+            x_preeq,
+            mask_reinit,
+            x_reinit,
+            init_override,
+            init_override_mask,
+            ts_mask,
+            ret,
+        )
+
     @eqx.filter_jit
     def preequilibrate_condition(
         self,
diff --git a/python/tests/test_jax.py b/python/tests/test_jax.py
index 37073f6701..f2354b9bff 100644
--- a/python/tests/test_jax.py
+++ b/python/tests/test_jax.py
@@ -206,7 +206,9 @@ def check_fields_jax(
         "steady_state_event": diffrax.steady_state_event(),
         "max_steps": 2**8,  # max_steps
     }
-    fun = beartype(jax_model.simulate_condition)
+    # beartype cannot introspect jitted functions, so type-check the
+    # unjitted version instead
+    fun = beartype(jax_model.simulate_condition_unjitted)
     for output in ["llh", "x0", "x", "y", "res"]:
         okwargs = kwargs | {
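
The pattern introduced by this patch, reduced to a minimal self-contained sketch: `eqx.filter_jit` returns a wrapper whose signature beartype cannot introspect, so the implementation lives in an unjitted method that beartype can check, while the public entry point stays a thin jitted delegate. The `Model` class, `scale` field, and `simulate*` method names below are hypothetical stand-ins, not AMICI's actual API.

```python
# Minimal sketch of the jitted-wrapper pattern from the diff above.
# `Model`, `scale`, and `simulate*` are hypothetical names for illustration.
import equinox as eqx
import jax.numpy as jnp
from beartype import beartype
from jaxtyping import Array, Float


class Model(eqx.Module):
    scale: float

    def simulate_unjitted(self, x: Float[Array, "n"]) -> Float[Array, "n"]:
        """Unjitted implementation; beartype can inspect this signature."""
        return self.scale * x

    @eqx.filter_jit
    def simulate(self, x: Float[Array, "n"]) -> Float[Array, "n"]:
        """JIT-compiled wrapper that delegates to the unjitted implementation."""
        return self.simulate_unjitted(x)


model = Model(scale=2.0)
model.simulate(jnp.ones(3))  # production path: compiled once, reused thereafter

# As in test_jax.py: wrap the *unjitted* method for runtime type checking.
checked = beartype(model.simulate_unjitted)
checked(jnp.ones(3))  # passes: rank-1 float array matches Float[Array, "n"]
```

With this split, only the beartype-wrapped unjitted call pays per-call type-checking overhead, and the delegate keeps the two code paths from drifting apart, since the jitted method contains no logic of its own.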