Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
6cea00b
event assignments jax - sbml cases 348 - 404
BSnelling Dec 1, 2025
46ff884
fix up sbml test cases - not implemented priority, update t_eps, fix …
BSnelling Dec 5, 2025
728e828
initialValue False not implemented
BSnelling Dec 5, 2025
20495d2
try fix other test cases
BSnelling Dec 5, 2025
1533f8f
Matrix only for JAX event assignments
BSnelling Dec 5, 2025
67c2eb0
params only in explicit triggers - and matrix only in JAX again
BSnelling Dec 5, 2025
533b97e
oops committed breakpoint
BSnelling Dec 5, 2025
6b53bb4
fix delta variables in deltax
BSnelling Dec 5, 2025
af08ed3
new param _delta_x missing in solve calls
BSnelling Dec 5, 2025
b4b3219
try simpler roots direction logic in handle event
BSnelling Dec 8, 2025
0545a07
try not logic in handle event
BSnelling Dec 8, 2025
06a208b
looking for initialValue test cases
BSnelling Dec 8, 2025
8398a38
add h = 0 check to handle event
BSnelling Dec 9, 2025
9aa9866
do not update h pre-solve
BSnelling Dec 9, 2025
523cf20
handle_t0_event
BSnelling Dec 9, 2025
7f5fdab
reinstate time skip (hack diffrax bug?)
BSnelling Dec 9, 2025
3b9471e
Update python/sdist/amici/jax/_simulation.py
BSnelling Dec 10, 2025
b93d884
Revert "Update python/sdist/amici/jax/_simulation.py"
BSnelling Dec 10, 2025
a4be718
rm clip controller
BSnelling Dec 10, 2025
0ab7c68
handle t0 event near zero
BSnelling Dec 10, 2025
4a47d3b
skip non-time dependent event assignment cases
BSnelling Dec 10, 2025
bc51bb4
fix sbml _symbols
BSnelling Dec 11, 2025
449ed78
skip some more tests - NotImplemented discs
BSnelling Dec 11, 2025
e9f47e3
update implicit check and skip SalazarCavazos_MBoC2020 benchmark
BSnelling Dec 12, 2025
3618db8
empty set is not None
BSnelling Dec 12, 2025
900eaef
update solver settings for jax benchmarks
BSnelling Dec 12, 2025
51a7e9c
keep Weber settings specific
BSnelling Dec 12, 2025
d49572e
review comments - remove x_old usage and add TODO/FIXMEs
BSnelling Dec 15, 2025
c4a4166
Merge branch 'main' into bes/jax_event_assignments
BSnelling Dec 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 28 additions & 8 deletions python/sdist/amici/_symbolic/de_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1141,6 +1141,8 @@ def _generate_symbol(self, name: str) -> None:
]
)
return
elif name == "deltax":
length = sp.Matrix(self.eq(name)).shape[0]
else:
length = len(self.eq(name))
self._syms[name] = sp.Matrix(
Expand Down Expand Up @@ -2623,30 +2625,30 @@ def _process_hybridization(self, hybridization: dict) -> None:
if added_expressions:
self.toposort_expressions()

def get_explicit_roots(self) -> set[sp.Expr]:
def get_explicit_roots(self) -> list[sp.Expr]:
"""
Returns explicit formulas for all discontinuities (events)
that can be precomputed

:return:
set of symbolic roots
list of symbolic roots
"""
return {root for e in self._events for root in e.get_trigger_times()}
return [root for e in self._events for root in e.get_trigger_times()]

def get_implicit_roots(self) -> set[sp.Expr]:
def get_implicit_roots(self) -> list[sp.Expr]:
"""
Returns implicit equations for all discontinuities (events)
that have to be located via rootfinding

:return:
set of symbolic roots
list of symbolic roots
"""
return {
return [
e.get_val()
for e in self._events
if not e.has_explicit_trigger_times()
}

]
def has_algebraic_states(self) -> bool:
"""
Checks whether the model has algebraic states
Expand All @@ -2664,6 +2666,24 @@ def has_event_assignments(self) -> bool:
boolean indicating if event assignments are present
"""
return any(event.updates_state for event in self._events)

def has_priority_events(self) -> bool:
"""
Checks whether the model has events with priorities defined

:return:
boolean indicating if priority events are present
"""
return any(event.get_priority() is not None for event in self._events)

def has_implicit_event_assignments(self) -> bool:
"""
Checks whether the model has event assignments with implicit triggers

:return:
boolean indicating if event assignments with implicit triggers are present
"""
return any(event.updates_state and not event.has_explicit_trigger_times({}) for event in self._events)

def toposort_expressions(
self, reorder: bool = True
Expand Down
2 changes: 1 addition & 1 deletion python/sdist/amici/importers/petab/_petab_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def _do_import_sbml(self):

show_model_info(self.petab_problem.model.sbml_model)
sbml_importer = amici.SbmlImporter(
self.petab_problem.model.sbml_model,
self.petab_problem.model.sbml_model, jax=self._jax
)

self._check_placeholders()
Expand Down
1 change: 1 addition & 0 deletions python/sdist/amici/importers/petab/v1/sbml_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ def import_model_sbml(
sbml_importer = amici.SbmlImporter(
sbml_model,
discard_annotations=discard_sbml_annotations,
jax=jax,
)
sbml_model = sbml_importer.sbml_model

Expand Down
17 changes: 17 additions & 0 deletions python/sdist/amici/importers/sbml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def __init__(
show_sbml_warnings: bool = False,
from_file: bool = True,
discard_annotations: bool = False,
jax: bool = False,
) -> None:
"""
Initialize.
Expand Down Expand Up @@ -188,6 +189,7 @@ def __init__(
ignore_units=True,
evaluate=True,
)
self.jax = jax

@log_execution_time("loading SBML", logger)
def _process_document(self) -> None:
Expand Down Expand Up @@ -1893,6 +1895,21 @@ def _process_events(self) -> None:
"priority": self._sympify(event.getPriority()),
}

if self.jax:
# Add a negative event for JAX models to handle
# TODO: remove once condition function directions can be
# traced through diffrax solve
neg_event_id = event_id + "_negative"
neg_event_sym = sp.Symbol(neg_event_id)
self._symbols[SymbolId.EVENT][neg_event_sym] = {
"name": neg_event_id,
"value": -trigger,
"assignments": None,
"initial_value": not initial_value,
"use_values_from_trigger_time": use_trig_val,
"priority": self._sympify(event.getPriority()),
}

@log_execution_time("processing observation model", logger)
def _process_observation_model(
self,
Expand Down
91 changes: 52 additions & 39 deletions python/sdist/amici/jax/_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def eq(
term: diffrax.ODETerm,
root_cond_fns: list[Callable],
root_cond_fn: Callable,
delta_x: Callable,
known_discs: jt.Float[jt.Array, "*nediscs"],
max_steps: jnp.int_,
) -> tuple[jt.Float[jt.Array, "nxs"], jt.Float[jt.Array, "ne"], dict]:
Expand Down Expand Up @@ -61,6 +62,8 @@ def eq(
list of individual root condition functions for discontinuities
:param root_cond_fn:
root condition function for all discontinuities
:param delta_x:
function to compute state changes at events
:param known_discs:
known discontinuities, used to clip the step size controller
:param max_steps:
Expand All @@ -86,7 +89,6 @@ def eq(
[None],
diffrax.SaveAt(t1=True),
term,
known_discs,
dict(**STARTING_STATS),
)
y1 = jnp.where(
Expand Down Expand Up @@ -123,7 +125,6 @@ def body_fn(carry):
[None] + [True] * len(root_cond_fns),
diffrax.SaveAt(t1=True),
term,
known_discs,
stats,
)
y0_next = jnp.where(
Expand All @@ -136,19 +137,16 @@ def body_fn(carry):
)
t0_next = jnp.where(jnp.isfinite(sol.ts), sol.ts, -jnp.inf).max()

y0_next, t0_next, h_next, stats = _handle_event(
y0_next, h_next, stats = _handle_event(
t0_next,
jnp.inf,
y0_next,
p,
tcl,
h,
solver,
controller,
root_finder,
diffrax.DirectAdjoint(),
term,
root_cond_fn,
delta_x,
stats,
)

Expand Down Expand Up @@ -186,6 +184,7 @@ def solve(
term: diffrax.ODETerm,
root_cond_fns: list[Callable],
root_cond_fn: Callable,
delta_x: Callable,
known_discs: jt.Float[jt.Array, "*nediscs"],
) -> tuple[jt.Float[jt.Array, "nt nxs"], jt.Float[jt.Array, "nt ne"], dict]:
"""
Expand Down Expand Up @@ -213,6 +212,8 @@ def solve(
list of individual root condition functions for discontinuities
:param root_cond_fn:
root condition function for all discontinuities
:param delta_x:
function to compute state changes at events
:param known_discs:
known discontinuities, used to clip the step size controller
:return:
Expand All @@ -237,7 +238,6 @@ def solve(
[],
diffrax.SaveAt(ts=ts),
term,
known_discs,
dict(**STARTING_STATS),
)
return sol.ys, jnp.repeat(h[None, :], sol.ys.shape[0]), stats
Expand All @@ -254,7 +254,7 @@ def cond_fn(carry):

def body_fn(carry):
ys, t_start, y0, hs, h, stats = carry
sol, idx, stats = _run_segment(
sol, _, stats = _run_segment(
t_start,
ts[-1],
y0,
Expand All @@ -277,7 +277,6 @@ def body_fn(carry):
]
),
term,
known_discs,
stats,
)
# update the solution for all timepoints in the simulated segment
Expand All @@ -291,28 +290,22 @@ def body_fn(carry):
y0_next = sol.ys[1][
-1
] # next initial state is the last state of the current segment
ts_next = jnp.where(
ts > t0_next, ts, ts[-1]
).min() # timepoint of next datapoint, don't step over that

y0_next, t0_next, h_next, stats = _handle_event(
y0_next, h_next, stats = _handle_event(
t0_next,
ts_next,
y0_next,
p,
tcl,
h,
solver,
controller,
root_finder,
adjoint,
term,
root_cond_fn,
delta_x,
stats,
)

was_event = jnp.isin(ts, sol.ts[1])
hs = jnp.where(was_event[:, None], h_next[None, :], hs)
after_event = sol.ts[1] < ts
hs = jnp.where(after_event[:, None], h_next[None, :], hs)

return ys, t0_next, y0_next, hs, h_next, stats

Expand Down Expand Up @@ -351,7 +344,6 @@ def _run_segment(
cond_dirs: list[None | bool],
saveat: diffrax.SaveAt,
term: diffrax.ODETerm,
known_discs: jt.Float[jt.Array, "*nediscs"],
stats: dict,
) -> tuple[diffrax.Solution, int, dict]:
"""Solve a single integration segment and return triggered event index, start time for the next segment,
Expand All @@ -373,16 +365,6 @@ def _run_segment(
else None
)

# manage events with explicit discontinuities
controller = (
diffrax.ClipStepSizeController(
controller,
jump_ts=known_discs,
)
if known_discs.size
else controller
)

sol = diffrax.diffeqsolve(
term,
solver,
Expand Down Expand Up @@ -429,17 +411,14 @@ def _run_segment(

def _handle_event(
t0_next: float,
t_max: float,
y0_next: jt.Float[jt.Array, "nxs"],
p: jt.Float[jt.Array, "np"],
tcl: jt.Float[jt.Array, "ncl"],
h: jt.Float[jt.Array, "ne"],
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
root_finder: AbstractRootFinder,
adjoint: diffrax.AbstractAdjoint,
term: diffrax.ODETerm,
root_cond_fn: Callable,
delta_x: Callable,
stats: dict,
):
args = (p, tcl, h)
Expand All @@ -457,11 +436,15 @@ def _handle_event(
)
roots_dir = jnp.sign(droot_dt) # direction of the root condition function

h_next = h + jnp.where(
y0_next, h_next = _apply_event_assignments(
roots_found,
roots_dir,
jnp.zeros_like(h),
) # update heaviside variables based on the root condition function
y0_next,
p,
tcl,
h,
delta_x,
)

if os.getenv("JAX_DEBUG") == "1":
jax.debug.print(
Expand All @@ -472,4 +455,34 @@ def _handle_event(
h,
h_next,
)
return y0_next, t0_next, h_next, stats

return y0_next, h_next, stats

def _apply_event_assignments(
roots_found,
roots_dir,
y0_next,
p,
tcl,
h,
delta_x,
):
h_next = jnp.where(
roots_found,
jnp.logical_not(h),
h,
) # update heaviside variables based on the root condition function

mask = jnp.array(
[
(roots_found & (roots_dir > 0.0) & (h == 0.0))
for _ in range(y0_next.shape[0])
]
).T
delx = delta_x(y0_next, p, tcl)
if y0_next.size:
delx = delx.reshape(delx.size // y0_next.shape[0], y0_next.shape[0],)
y0_up = jnp.where(mask, delx, 0.0)
y0_next = y0_next + jnp.sum(y0_up, axis=0)

return y0_next, h_next
Loading
Loading