Skip to content

Commit a5b708c

Browse files
authored
Add JAX SBML test suite (#2795)
* Add JAX SBML test suite * Refine JAX SBML tests * Move SBML tests to subfolder * Remove unused test package markers * Refactor SBML tests directory * Move SBML tests back to tests/sbml * Consolidate SBML test scripts with --jax option * smaller fixes, add error for jax models with event assignemnts * enable x64 precision, add error for algebraic systems * fix t in x0 and expressionIds, specific lower tolerances * Handle default parameters for JAX models (#2800) * Apply suggestions from code review * Update test_sbml_semantic_test_suite_jax.yml * Update testSBMLSuiteJax.py * fix inf/nan * small fixes * reintroduce test matrix * update test readme * move sbml testsuite * Update conftest.py * fix petab benchmark? * Update test_petab_benchmark_jax.py * Update .github/workflows/test_sbml_semantic_test_suite_jax.yml * enable 01395
1 parent a7be07a commit a5b708c

File tree

15 files changed

+400
-54
lines changed

15 files changed

+400
-54
lines changed

.github/workflows/test_sbml_semantic_test_suite.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ on:
1414
- python/sdist/amici/sbml_import.py
1515
- python/sdist/amici/import_utils.py
1616
- scripts/run-SBMLTestsuite.sh
17-
- tests/testSBMLSuite.py
18-
- tests/conftest.py
17+
- tests/sbml/testSBMLSuite.py
18+
- tests/sbml/conftest.py
1919
check_suite:
2020
types: [requested]
2121
workflow_dispatch:
@@ -52,7 +52,7 @@ jobs:
5252
uses: actions/upload-artifact@v4
5353
with:
5454
name: amici-semantic-results-${{ matrix.cases }}
55-
path: tests/amici-semantic-results
55+
path: tests/sbml/amici-semantic-results
5656

5757
- name: Codecov SBMLSuite
5858
if: github.event_name == 'pull_request' || github.repository_owner == 'AMICI-dev'
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
name: SBML JAX
2+
on:
3+
push:
4+
branches:
5+
- develop
6+
- master
7+
- release**
8+
pull_request:
9+
paths:
10+
- .github/workflows/test_sbml_semantic_test_suite_jax.yml
11+
- python/sdist/amici/jax/**
12+
- scripts/run-SBMLTestsuite.sh
13+
- tests/sbml/testSBMLSuiteJax.py
14+
- tests/sbml/conftest.py
15+
check_suite:
16+
types: [requested]
17+
workflow_dispatch:
18+
19+
jobs:
20+
build:
21+
name: SBML Semantic Test Suite JAX
22+
runs-on: ubuntu-24.04
23+
24+
strategy:
25+
fail-fast: false
26+
matrix:
27+
cases: ["1-600", "601-1200", "1200-"]
28+
python-version: [ "3.13" ]
29+
30+
steps:
31+
- name: Set up Python ${{ matrix.python-version }}
32+
uses: actions/setup-python@v5
33+
with:
34+
python-version: ${{ matrix.python-version }}
35+
36+
37+
- uses: actions/checkout@v4
38+
with:
39+
fetch-depth: 1
40+
41+
- name: Install apt dependencies
42+
uses: ./.github/actions/install-apt-dependencies
43+
44+
- run: AMICI_PARALLEL_COMPILE="" ./scripts/installAmiciSource.sh
45+
- run: ./scripts/run-SBMLTestsuite.sh --jax ${{ matrix.cases }}
46+
47+
- name: "Upload artifact: SBML semantic test suite results"
48+
uses: actions/upload-artifact@v4
49+
with:
50+
name: amici-semantic-results-jax-${{ matrix.cases }}
51+
path: tests/sbml/amici-semantic-results-jax
52+
53+
- name: Codecov SBMLSuiteJax
54+
if: github.event_name == 'pull_request' || github.repository_owner == 'AMICI-dev'
55+
uses: codecov/codecov-action@v5
56+
with:
57+
token: ${{ secrets.CODECOV_TOKEN }}
58+
files: coverage_SBMLSuite_jax.xml
59+
flags: sbmlsuite-jax
60+
fail_ci_if_error: true

.gitignore

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,11 +133,10 @@ tests/test/*
133133
*/tests/dimerization/*
134134
tests/cpp/writeResults.h5
135135
tests/cpp/writeResults.h5.bak
136-
tests/sbml-test-suite/*
137-
tests/sbml-test-suite/
138-
tests/sedml-test-suite/
136+
tests/sbml/sbml-test-suite/*
137+
tests/sbml/sbml-test-suite/
139138
*/sbml-semantic-test-cases/*
140-
tests/SBMLTestModels/
139+
tests/sbml/SBMLTestModels/
141140
tests/benchmark-models/test_bmc
142141
tests/petab_test_suite
143142
petab_test_suite

doc/CI.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,11 @@ To run the SBML Test Suite test cases, the easiest way is:
4343
2. Running `scripts/run-SBMLTestsuite.sh`. This will download the test cases
4444
if necessary and run them all. A subset of test cases can be selected with
4545
an optional argument (e.g. `scripts/run-SBMLTestsuite.sh 1,3-6,8`, to run
46-
cases 1, 3, 4, 5, 6 and 8).
46+
cases 1, 3, 4, 5, 6 and 8). Use `--jax` as the first argument to run the
47+
SBML tests using the JAX backend.
4748

4849
Once the test cases are available locally, for debugging it might be easier
49-
to directly use `pytest` with `tests/testSBMLSuite.py`.
50+
to directly use `pytest` with `tests/sbml/testSBMLSuite.py`.
5051

5152

5253
## Matlab tests (not included in CI pipeline)

python/sdist/amici/jax/jax.template.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import jax.numpy as jnp
33
from interpax import interp1d
44
from pathlib import Path
5+
from jax.numpy import inf as oo
6+
from jax.numpy import nan as nan
57

68
from amici.jax.model import JAXModel, safe_log, safe_div
79

@@ -11,6 +13,7 @@ class JAXModel_TPL_MODEL_NAME(JAXModel):
1113

1214
def __init__(self):
1315
self.jax_py_file = Path(__file__).resolve()
16+
self.parameters = TPL_P_VALUES
1417
super().__init__()
1518

1619
def _xdot(self, t, x, args):
@@ -34,7 +37,7 @@ def _w(self, t, x, p, tcl):
3437

3538
return TPL_W_RET
3639

37-
def _x0(self, p):
40+
def _x0(self, t, p):
3841
TPL_P_SYMS = p
3942

4043
TPL_X0_EQ
@@ -86,6 +89,9 @@ def _sigmay(self, y, p, np):
8689

8790
def _nllh(self, t, x, p, tcl, my, iy, op, np):
8891
y = self._y(t, x, p, tcl, op)
92+
if not y.size:
93+
return jnp.array(0.0)
94+
8995
TPL_Y_SYMS = y
9096
TPL_SIGMAY_SYMS = self._sigmay(y, p, np)
9197

@@ -105,5 +111,9 @@ def state_ids(self):
105111
def parameter_ids(self):
106112
return TPL_P_IDS
107113

114+
@property
115+
def expression_ids(self):
116+
return TPL_W_IDS
117+
108118

109119
Model = JAXModel_TPL_MODEL_NAME

python/sdist/amici/jax/jaxcodeprinter.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@
88
from sympy.printing.numpy import NumPyPrinter
99

1010

11+
def _jnp_array_str(array) -> str:
12+
elems = ", ".join(str(s) for s in array)
13+
14+
return f"jnp.array([{elems}])"
15+
16+
1117
class AmiciJaxCodePrinter(NumPyPrinter):
1218
"""JAX code printer"""
1319

@@ -36,6 +42,18 @@ def _print_Mul(self, expr: sp.Expr) -> str:
3642
return super()._print_Mul(expr)
3743
return f"safe_div({self.doprint(numer)}, {self.doprint(denom)})"
3844

45+
def _print_Max(self, expr: sp.Expr) -> str:
46+
"""
47+
Print the max function, replacing it with jnp.max.
48+
"""
49+
return f"jnp.max({_jnp_array_str(expr.args)})"
50+
51+
def _print_Min(self, expr: sp.Expr) -> str:
52+
"""
53+
Print the min function, replacing it with jnp.min.
54+
"""
55+
return f"jnp.min({_jnp_array_str(expr.args)})"
56+
3957
def _get_sym_lines(
4058
self,
4159
symbols: sp.Matrix | Iterable[str],

python/sdist/amici/jax/model.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from abc import abstractmethod
66
from pathlib import Path
77
import enum
8+
from dataclasses import field
89

910
import diffrax
1011
import equinox as eqx
@@ -43,9 +44,10 @@ class JAXModel(eqx.Module):
4344
Path to the JAX model file.
4445
"""
4546

46-
MODEL_API_VERSION = "0.0.3"
47+
MODEL_API_VERSION = "0.0.4"
4748
api_version: str
4849
jax_py_file: Path
50+
parameters: jnp.ndarray = field(default_factory=lambda: jnp.array([]))
4951

5052
def __init__(self):
5153
if self.api_version != self.MODEL_API_VERSION:
@@ -93,11 +95,16 @@ def _w(
9395
...
9496

9597
@abstractmethod
96-
def _x0(self, p: jt.Float[jt.Array, "np"]) -> jt.Float[jt.Array, "nx"]:
98+
def _x0(
99+
self, t: jnp.float_, p: jt.Float[jt.Array, "np"]
100+
) -> jt.Float[jt.Array, "nx"]:
97101
"""
98102
Compute the initial state vector.
99103
104+
:param t: initial time point
100105
:param p: parameters
106+
:return:
107+
Initial state vector.
101108
"""
102109
...
103110

@@ -264,6 +271,17 @@ def parameter_ids(self) -> list[str]:
264271
"""
265272
...
266273

274+
@property
275+
@abstractmethod
276+
def expression_ids(self) -> list[str]:
277+
"""
278+
Get the expression ids of the model.
279+
280+
:return:
281+
Expression ids
282+
"""
283+
...
284+
267285
def _eq(
268286
self,
269287
p: jt.Float[jt.Array, "np"],
@@ -496,7 +514,7 @@ def _sigmays(
496514
@eqx.filter_jit
497515
def simulate_condition(
498516
self,
499-
p: jt.Float[jt.Array, "np"],
517+
p: jt.Float[jt.Array, "np"] | None,
500518
ts_dyn: jt.Float[jt.Array, "nt_dyn"],
501519
ts_posteq: jt.Float[jt.Array, "nt_posteq"],
502520
my: jt.Float[jt.Array, "nt"],
@@ -521,7 +539,8 @@ def simulate_condition(
521539
Simulate a condition.
522540
523541
:param p:
524-
parameters for simulation ordered according to ids in :ivar parameter_ids:
542+
parameters for simulation ordered according to ids in :ivar parameter_ids:. If ``None``,
543+
the values stored in :attr:`parameters` are used.
525544
:param ts_dyn:
526545
time points for dynamic simulation. Sorted in monotonically increasing order but duplicate time points are
527546
allowed to facilitate the evaluation of multiple observables at specific time points.
@@ -564,10 +583,13 @@ def simulate_condition(
564583
:return:
565584
output according to `ret` and general results/statistics
566585
"""
586+
if p is None:
587+
p = self.parameters
588+
567589
if x_preeq.shape[0]:
568590
x = x_preeq
569591
else:
570-
x = self._x0(p)
592+
x = self._x0(0.0, p)
571593

572594
if not ts_mask.shape[0]:
573595
ts_mask = jnp.ones_like(my, dtype=jnp.bool_)
@@ -675,7 +697,7 @@ def simulate_condition(
675697
@eqx.filter_jit
676698
def preequilibrate_condition(
677699
self,
678-
p: jt.Float[jt.Array, "np"],
700+
p: jt.Float[jt.Array, "np"] | None,
679701
x_reinit: jt.Float[jt.Array, "*nx"],
680702
mask_reinit: jt.Bool[jt.Array, "*nx"],
681703
solver: diffrax.AbstractSolver,
@@ -689,7 +711,8 @@ def preequilibrate_condition(
689711
Simulate a condition.
690712
691713
:param p:
692-
parameters for simulation ordered according to ids in :ivar parameter_ids:
714+
parameters for simulation ordered according to ids in :ivar parameter_ids:. If ``None``,
715+
the values stored in :attr:`parameters` are used.
693716
:param x_reinit:
694717
re-initialized state vector. If not provided, the state vector is not re-initialized.
695718
:param mask_reinit:
@@ -704,7 +727,10 @@ def preequilibrate_condition(
704727
pre-equilibrated state variables and statistics
705728
"""
706729
# Pre-equilibration
707-
x0 = self._x0(p)
730+
if p is None:
731+
p = self.parameters
732+
733+
x0 = self._x0(0.0, p)
708734
if x_reinit.shape[0]:
709735
x0 = jnp.where(mask_reinit, x_reinit, x0)
710736
tcl = self._tcl(x0, p)

python/sdist/amici/jax/ode_export.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
)
2323

2424
from amici._codegen.template import apply_template
25-
from amici.jax.jaxcodeprinter import AmiciJaxCodePrinter
25+
from amici.jax.jaxcodeprinter import AmiciJaxCodePrinter, _jnp_array_str
2626
from amici.jax.model import JAXModel
2727
from amici.de_model import DEModel
2828
from amici.de_export import is_valid_identifier
@@ -96,12 +96,6 @@ def _jax_variable_ids(model: DEModel, sym_names: tuple[str, ...]) -> dict:
9696
}
9797

9898

99-
def _jnp_array_str(array) -> str:
100-
elems = ", ".join(str(s) for s in array)
101-
102-
return f"jnp.array([{elems}])"
103-
104-
10599
class ODEExporter:
106100
"""
107101
The ODEExporter class generates AMICI jax files for a model as
@@ -148,6 +142,16 @@ def __init__(
148142
"""
149143
set_log_level(logger, verbose)
150144

145+
if any(event.updates_state for event in ode_model._events):
146+
raise NotImplementedError(
147+
"The JAX backend does not support models with event assignments."
148+
)
149+
150+
if ode_model._algebraic_equations:
151+
raise NotImplementedError(
152+
"The JAX backend does not support models with algebraic equations."
153+
)
154+
151155
self.verbose: bool = logger.getEffectiveLevel() <= logging.DEBUG
152156

153157
self.model_path: Path = Path()
@@ -237,7 +241,8 @@ def _generate_jax_code(self) -> None:
237241
# assign named variables from a jax array
238242
**_jax_variable_assignments(self.model, sym_names),
239243
# tuple of variable names (ids as they are unique)
240-
**_jax_variable_ids(self.model, ("p", "k", "y", "x_rdata")),
244+
**_jax_variable_ids(self.model, ("p", "k", "y", "w", "x_rdata")),
245+
"P_VALUES": _jnp_array_str(self.model.val("p")),
241246
**{
242247
"MODEL_NAME": self.model_name,
243248
# keep track of the API version that the model was generated with so we

python/sdist/amici/sbml_import.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1029,7 +1029,8 @@ def _gather_base_locals(
10291029
if not x_ref.isSetId():
10301030
continue
10311031
if (
1032-
x_ref.isSetStoichiometry()
1032+
hasattr(x_ref, "isSetStoichiometry")
1033+
and x_ref.isSetStoichiometry()
10331034
and not self.is_assignment_rule_target(x_ref)
10341035
):
10351036
value = sp.Float(x_ref.getStoichiometry())
@@ -2986,7 +2987,7 @@ def _get_element_stoichiometry(self, ele: libsbml.SBase) -> sp.Expr:
29862987
if self.is_assignment_rule_target(ele):
29872988
return _get_identifier_symbol(ele)
29882989

2989-
if ele.isSetStoichiometry():
2990+
if hasattr(ele, "isSetStoichiometry") and ele.isSetStoichiometry():
29902991
stoichiometry: float = ele.getStoichiometry()
29912992
return (
29922993
sp.Integer(stoichiometry)

0 commit comments

Comments
 (0)