Skip to content

Commit 83e1887

Browse files
FFroehlichdweindl
andauthored
Track discontinuities in JAX & fix gradient for models with parameter dependent roots (#2815)
* add tracking of known roots * respect private APIs * pin diffrax * Update test_sbml_semantic_test_suite_jax.yml * Update python/sdist/amici/jax/ode_export.py Co-authored-by: Daniel Weindl <[email protected]> --------- Co-authored-by: Daniel Weindl <[email protected]>
1 parent 33df902 commit 83e1887

File tree

8 files changed

+124
-6
lines changed

8 files changed

+124
-6
lines changed

.github/workflows/test_benchmark_collection_models.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@ jobs:
134134
&& python3 -m pip install -U sympy \
135135
&& python3 -m pip install git+https://github.com/ICB-DCM/fiddy.git
136136
137+
- run: pip uninstall -y diffrax && pip install git+https://github.com/patrick-kidger/diffrax # TODO FIXME https://github.com/patrick-kidger/diffrax/issues/654
138+
137139
- name: Download benchmark collection
138140
run: |
139141
pip install git+https://github.com/Benchmarking-Initiative/Benchmark-Models-PEtab.git@master#subdirectory=src/python

.github/workflows/test_sbml_semantic_test_suite_jax.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ jobs:
4242
uses: ./.github/actions/install-apt-dependencies
4343

4444
- run: AMICI_PARALLEL_COMPILE="" ./scripts/installAmiciSource.sh
45+
- run: source ./venv/bin/activate && pip uninstall -y diffrax && pip install git+https://github.com/patrick-kidger/diffrax # TODO FIXME https://github.com/patrick-kidger/diffrax/issues/654
4546
- run: ./scripts/run-SBMLTestsuite.sh --jax ${{ matrix.cases }}
4647

4748
- name: "Upload artifact: SBML semantic test suite results"

python/sdist/amici/de_model.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2362,3 +2362,57 @@ def _process_heavisides(
23622362
dxdt = dxdt.subs(heaviside_sympy, heaviside_amici)
23632363

23642364
return dxdt
2365+
2366+
def get_explicit_roots(self) -> set[sp.Expr]:
2367+
"""
2368+
Returns explicit formulas for all discontinuities (events)
2369+
that can be precomputed
2370+
2371+
:return:
2372+
set of symbolic roots
2373+
"""
2374+
return {root for e in self._events for root in e.get_trigger_times()}
2375+
2376+
def get_implicit_roots(self) -> set[sp.Expr]:
2377+
"""
2378+
Returns implicit equations for all discontinuities (events)
2379+
that have to be located via rootfinding
2380+
2381+
:return:
2382+
set of symbolic roots
2383+
"""
2384+
return {
2385+
e.get_val()
2386+
for e in self._events
2387+
if not e.has_explicit_trigger_times()
2388+
}
2389+
2390+
def has_algebraic_states(self) -> bool:
2391+
"""
2392+
Checks whether the model has algebraic states
2393+
2394+
:return:
2395+
boolean indicating if algebraic states are present
2396+
"""
2397+
return len(self._algebraic_states) > 0
2398+
2399+
def has_event_assignments(self) -> bool:
2400+
"""
2401+
Checks whether the model has event assignments
2402+
2403+
:return:
2404+
boolean indicating if event assignments are present
2405+
"""
2406+
return any(event.updates_state for event in self._events)
2407+
2408+
def has_parameter_dependent_implicit_roots(self) -> bool:
2409+
"""
2410+
Checks whether the model has events with parameter-dependent implicit roots
2411+
2412+
:return:
2413+
boolean indicating if parameter-dependent implicit roots are present
2414+
"""
2415+
return any(
2416+
self.sym("p").has(root.free_symbols)
2417+
for root in self.get_implicit_roots()
2418+
)

python/sdist/amici/de_model_components.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -829,6 +829,22 @@ def get_trigger_time(self) -> sp.Float:
829829
)
830830
return self._t_root[0]
831831

832+
def has_explicit_trigger_times(self) -> bool:
833+
"""Check whether the event has explicit trigger times.
834+
835+
Explicit trigger times do not require root finding to determine
836+
the time points at which the event triggers.
837+
"""
838+
return len(self._t_root) > 0
839+
840+
def get_trigger_times(self) -> set[sp.Expr]:
841+
"""Get the time points at which the event triggers.
842+
843+
Returns a set of expressions, which may contain multiple time points
844+
for events that trigger at multiple time points.
845+
"""
846+
return set(self._t_root)
847+
832848
@property
833849
def uses_values_from_trigger_time(self) -> bool:
834850
"""Whether the event assignment is evaluated using the state from

python/sdist/amici/jax/jax.template.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# ruff: noqa: F401, F821, F841
22
import jax.numpy as jnp
3+
import jaxtyping as jt
34
from interpax import interp1d
45
from pathlib import Path
56
from jax.numpy import inf as oo
@@ -99,6 +100,11 @@ def _nllh(self, t, x, p, tcl, my, iy, op, np):
99100

100101
return TPL_JY_RET.at[iy].get()
101102

103+
def _known_discs(self, p):
104+
TPL_P_SYMS = p
105+
106+
return TPL_ROOTS
107+
102108
@property
103109
def observable_ids(self):
104110
return TPL_Y_IDS

python/sdist/amici/jax/model.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,20 @@ def _nllh(
238238
"""
239239
...
240240

241+
@abstractmethod
242+
def _known_discs(
243+
self, p: jt.Float[jt.Array, "np"]
244+
) -> jt.Float[jt.Array, "ndiscs"]:
245+
"""
246+
Compute the known discontinuity points of the ODE system.
247+
248+
:param p:
249+
parameters
250+
:return:
251+
known discontinuity points in the ODE system
252+
"""
253+
...
254+
241255
@property
242256
@abstractmethod
243257
def state_ids(self) -> list[str]:
@@ -319,7 +333,9 @@ def _eq(
319333
t1=jnp.inf,
320334
dt0=None,
321335
y0=x0,
322-
stepsize_controller=controller,
336+
stepsize_controller=self._get_clipped_stepsize_controller(
337+
p, controller
338+
),
323339
max_steps=max_steps,
324340
adjoint=diffrax.DirectAdjoint(),
325341
event=diffrax.Event(
@@ -378,7 +394,9 @@ def _solve(
378394
t1=ts[-1],
379395
dt0=None,
380396
y0=x0,
381-
stepsize_controller=controller,
397+
stepsize_controller=self._get_clipped_stepsize_controller(
398+
p, controller
399+
),
382400
max_steps=max_steps,
383401
adjoint=adjoint,
384402
saveat=diffrax.SaveAt(ts=ts),
@@ -747,6 +765,19 @@ def preequilibrate_condition(
747765

748766
return self._x_rdata(current_x, tcl), dict(stats_preeq=stats_preeq)
749767

768+
def _get_clipped_stepsize_controller(
769+
self,
770+
p: jt.Float[jt.Array, "np"],
771+
controller: diffrax.AbstractStepSizeController,
772+
) -> diffrax.AbstractStepSizeController:
773+
if not self._known_discs(p).size:
774+
return controller
775+
776+
return diffrax.ClipStepSizeController(
777+
controller,
778+
jump_ts=self._known_discs(p),
779+
)
780+
750781

751782
def safe_log(x: jnp.float_) -> jnp.float_:
752783
"""

python/sdist/amici/jax/ode_export.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from amici.jax.jaxcodeprinter import AmiciJaxCodePrinter, _jnp_array_str
2626
from amici.jax.model import JAXModel
2727
from amici.de_model import DEModel
28+
2829
from amici.de_export import is_valid_identifier
2930
from amici.import_utils import (
3031
strip_pysb,
@@ -142,14 +143,18 @@ def __init__(
142143
"""
143144
set_log_level(logger, verbose)
144145

145-
if any(event.updates_state for event in ode_model._events):
146+
if ode_model.has_event_assignments():
146147
raise NotImplementedError(
147148
"The JAX backend does not support models with event assignments."
148149
)
149150

150-
if ode_model._algebraic_equations:
151+
if ode_model.has_algebraic_states():
152+
raise NotImplementedError(
153+
"The JAX backend does not support models with algebraic states."
154+
)
155+
if ode_model.has_parameter_dependent_implicit_roots():
151156
raise NotImplementedError(
152-
"The JAX backend does not support models with algebraic equations."
157+
"The JAX backend does not support models with parameter dependent implicit event triggers."
153158
)
154159

155160
self.verbose: bool = logger.getEffectiveLevel() <= logging.DEBUG
@@ -243,6 +248,9 @@ def _generate_jax_code(self) -> None:
243248
# tuple of variable names (ids as they are unique)
244249
**_jax_variable_ids(self.model, ("p", "k", "y", "w", "x_rdata")),
245250
"P_VALUES": _jnp_array_str(self.model.val("p")),
251+
"ROOTS": _jnp_array_str(
252+
{root for e in self.model._events for root in e.get_trigger_times()}
253+
),
246254
**{
247255
"MODEL_NAME": self.model_name,
248256
# keep track of the API version that the model was generated with so we

tests/benchmark_models/test_petab_benchmark_jax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def test_jax_llh(benchmark_problem):
5757
if problem_id in problems_for_gradient_check:
5858
point = flat_petab_problem.x_nominal_free_scaled
5959
for _ in range(20):
60-
amici_solver.setSensitivityMethod(amici.SensitivityMethod.adjoint)
60+
amici_solver.setSensitivityMethod(amici.SensitivityMethod.forward)
6161
amici_solver.setSensitivityOrder(amici.SensitivityOrder.first)
6262
amici_model.setSteadyStateSensitivityMode(
6363
cur_settings.ss_sensitivity_mode

0 commit comments

Comments
 (0)