Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 111 additions & 53 deletions python/sdist/amici/jax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,8 +523,7 @@ def _sigmays(
in_axes=(0, 0, None, None, 0, 0, 0, 0),
)(ts, xs, p, tcl, hs, iys, ops, nps)

@eqx.filter_jit
def simulate_condition(
def simulate_condition_unjitted(
self,
p: jt.Float[jt.Array, "np"] | None,
ts_dyn: jt.Float[jt.Array, "nt_dyn"],
Expand All @@ -549,58 +548,11 @@ def simulate_condition(
init_override_mask: jt.Bool[jt.Array, "*nx"] = jnp.array([]),
ts_mask: jt.Bool[jt.Array, "nt"] = jnp.array([]),
ret: ReturnValue = ReturnValue.llh,
) -> tuple[jt.Float[jt.Array, "nt *nx"] | jnp.float_, dict]:
r"""
Simulate a condition.
) -> tuple[jt.Float[jt.Array, "*nt"], dict]:
"""
Unjitted version of simulate_condition.

:param p:
parameters for simulation ordered according to ids in :ivar parameter_ids:. If ``None``,
the values stored in :attr:`parameters` are used.
:param ts_dyn:
time points for dynamic simulation. Sorted in monotonically increasing order but duplicate time points are
allowed to facilitate the evaluation of multiple observables at specific time points.
:param ts_posteq:
time points for post-equilibration. Usually valued \Infty, but needs to be shaped according to
the number of observables that are evaluated after post-equilibration.
:param my:
observed data
:param iys:
indices of the observables according to ordering in :ivar observable_ids:
:param iy_trafos:
indices of transformations for observables
:param ops:
observable parameters
:param nps:
noise parameters
:param solver:
ODE solver
:param controller:
step size controller
:param adjoint:
adjoint method. Recommended values are `diffrax.DirectAdjoint()` for jax.jacfwd (with vector-valued
outputs) and `diffrax.RecursiveCheckpointAdjoint()` for jax.grad (for scalar-valued outputs).
:param steady_state_event:
event function for steady state. See :func:`diffrax.steady_state_event` for details.
:param max_steps:
maximum number of solver steps
:param x_preeq:
initial state vector for pre-equilibration. If not provided, the initial state vector is computed using
:meth:`_x0`.
:param mask_reinit:
mask for re-initialization. If `True`, the corresponding state variable is re-initialized.
:param x_reinit:
re-initialized state vector. If not provided, the state vector is not re-initialized.
:param init_override:
override model input e.g. with neural net outputs. If not provided, the inputs are not overridden.
:param init_override_mask:
mask for input override. If `True`, the corresponding input is replaced with the corresponding value from `init_override`.
:param ts_mask:
mask to remove (padded) time points. If `True`, the corresponding time point is used for the evaluation of
the output. Only applied if ret is ReturnValue.llh, ReturnValue.nllhs, ReturnValue.res, or ReturnValue.chi2.
:param ret:
which output to return. See :class:`ReturnValue` for available options.
:return:
output according to `ret` and general results/statistics
See :meth:`simulate_condition` for full documentation.
"""
t0 = 0.0
if p is None:
Expand Down Expand Up @@ -736,6 +688,112 @@ def simulate_condition(

return output, stats

@eqx.filter_jit
def simulate_condition(
self,
p: jt.Float[jt.Array, "np"] | None,
ts_dyn: jt.Float[jt.Array, "nt_dyn"],
ts_posteq: jt.Float[jt.Array, "nt_posteq"],
my: jt.Float[jt.Array, "nt"],
iys: jt.Int[jt.Array, "nt"],
iy_trafos: jt.Int[jt.Array, "nt"],
ops: jt.Float[jt.Array, "nt *nop"],
nps: jt.Float[jt.Array, "nt *nnp"],
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
root_finder: AbstractRootFinder,
adjoint: diffrax.AbstractAdjoint,
steady_state_event: Callable[
..., diffrax._custom_types.BoolScalarLike
],
max_steps: int | jnp.int_,
x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]),
mask_reinit: jt.Bool[jt.Array, "*nx"] = jnp.array([]),
x_reinit: jt.Float[jt.Array, "*nx"] = jnp.array([]),
init_override: jt.Float[jt.Array, "*nx"] = jnp.array([]),
init_override_mask: jt.Bool[jt.Array, "*nx"] = jnp.array([]),
ts_mask: jt.Bool[jt.Array, "nt"] = jnp.array([]),
ret: ReturnValue = ReturnValue.llh,
) -> tuple[jt.Float[jt.Array, "*nt"], dict]:
r"""
Simulate a condition (JIT-compiled version).

This is the JIT-compiled version for optimal performance. For runtime type checking
with beartype, use :meth:`simulate_condition_unjitted` instead.

:param p:
parameters for simulation ordered according to ids in :ivar parameter_ids:. If ``None``,
the values stored in :attr:`parameters` are used.
:param ts_dyn:
time points for dynamic simulation. Sorted in monotonically increasing order but duplicate time points are
allowed to facilitate the evaluation of multiple observables at specific time points.
:param ts_posteq:
time points for post-equilibration. Usually valued \Infty, but needs to be shaped according to
the number of observables that are evaluated after post-equilibration.
:param my:
observed data
:param iys:
indices of the observables according to ordering in :ivar observable_ids:
:param iy_trafos:
indices of transformations for observables
:param ops:
observable parameters
:param nps:
noise parameters
:param solver:
ODE solver
:param controller:
step size controller
:param adjoint:
adjoint method. Recommended values are `diffrax.DirectAdjoint()` for jax.jacfwd (with vector-valued
outputs) and `diffrax.RecursiveCheckpointAdjoint()` for jax.grad (for scalar-valued outputs).
:param steady_state_event:
event function for steady state. See :func:`diffrax.steady_state_event` for details.
:param max_steps:
maximum number of solver steps
:param x_preeq:
initial state vector for pre-equilibration. If not provided, the initial state vector is computed using
:meth:`_x0`.
:param mask_reinit:
mask for re-initialization. If `True`, the corresponding state variable is re-initialized.
:param x_reinit:
re-initialized state vector. If not provided, the state vector is not re-initialized.
:param init_override:
override model input e.g. with neural net outputs. If not provided, the inputs are not overridden.
:param init_override_mask:
mask for input override. If `True`, the corresponding input is replaced with the corresponding value from `init_override`.
:param ts_mask:
mask to remove (padded) time points. If `True`, the corresponding time point is used for the evaluation of
the output. Only applied if ret is ReturnValue.llh, ReturnValue.nllhs, ReturnValue.res, or ReturnValue.chi2.
:param ret:
which output to return. See :class:`ReturnValue` for available options.
:return:
output according to `ret` and general results/statistics
"""
return self.simulate_condition_unjitted(
p,
ts_dyn,
ts_posteq,
my,
iys,
iy_trafos,
ops,
nps,
solver,
controller,
root_finder,
adjoint,
steady_state_event,
max_steps,
x_preeq,
mask_reinit,
x_reinit,
init_override,
init_override_mask,
ts_mask,
ret,
)

@eqx.filter_jit
def preequilibrate_condition(
self,
Expand Down
4 changes: 3 additions & 1 deletion python/tests/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,9 @@ def check_fields_jax(
"steady_state_event": diffrax.steady_state_event(),
"max_steps": 2**8, # max_steps
}
fun = beartype(jax_model.simulate_condition)
# Use beartype-wrapped unjitted version for type checking
# (beartype cannot introspect jitted functions, so we wrap the unjitted version)
fun = beartype(jax_model.simulate_condition_unjitted)

for output in ["llh", "x0", "x", "y", "res"]:
okwargs = kwargs | {
Expand Down
Loading