Skip to content

Commit 4360d5e

Browse files
FFroehlichBSnellingdweindl
authored
Support for PEtab SciML in jax (#2614)
* add jax serialisation * doc * no compilation for jax * bad ruff * Update ExampleJaxPEtab.ipynb * bad ruff * Update ExampleJaxPEtab.ipynb * add nan safe log&divide * some net cases and first ude testcase passing * some more passing ude test * updates * fix merge * remove changes doc * update net_004_alt test * refactor to pytests * fixup merge * fix net test-cases * fixes & remove sciml dependency * fixup, add initial condition support * Update petab.py * Update test_petab_sciml.yml * Update petab.py * Update petab.py * Update petab.py * Update ExampleJaxPEtab.ipynb * ignore test warning * Update test_petab_sciml.yml * add hybridisation support * fix workflow * update testsuite * update after test refactor * fix hybridization * update testsuite * update testsuite, some fixes * Update testsuite * fix #2687 * spec updates * updates for overhauled testsuite * Update nn.py * Update petab_import.py * Getting test_net petab-sciml tests to pass - update sciml testsuite submodule to point at main - fix a eqx LayerNorm deprecation warning - fix string formatting of bool in kwarg - updates to test code driven by updated sciml format * Implementing features for a subset of ude petab_sciml test cases. Excludes: - frozen nn layers - nns in observable formulae * update petab_sciml workflow - on branches and sciml install branch * updates to petab sciml workflow * fix undef local var in jax tests * frozen layers for RHS networks generalise frozen layers to networks across system Use stop_grad instead * implement nns in the observable formula * tidy, refactor, generalise sciml test case implementations * hybridization df in _petab_problem - makes JAXProblem jit-able * update frozen layer/arrays implementation * update jax petab notebook * add h5py to docs deps * fix sbml jax tests * missed rebased imports * codecov maybe * codecov - update cov file name * codecov - specify cov path * enable zero params case * doc build forward type definition workaround * simplify array input processing * bump versions, fixup notebook * safety around nn_output_ids * reinstate test * fix imports from amici.jax.nn * skip petab tests with mapping df * Apply suggestion from @dweindl Co-authored-by: Daniel Weindl <dweindl@users.noreply.github.com> * Update python/sdist/amici/petab/petab_import.py Co-authored-by: Daniel Weindl <dweindl@users.noreply.github.com> * Update python/sdist/amici/petab/petab_import.py Co-authored-by: Daniel Weindl <dweindl@users.noreply.github.com> * review comments * document hybridization table * update sciml repo * refactor and add documentation to nn code * print missing components * add tests * refactor _initialize_model_with_nominal_values * fix doc, canonical spelling * Update python/sdist/amici/jax/model.py Co-authored-by: Daniel Weindl <dweindl@users.noreply.github.com> * Update python/sdist/amici/jax/nn.py Co-authored-by: Daniel Weindl <dweindl@users.noreply.github.com> * Update python/sdist/amici/jax/nn.py Co-authored-by: Daniel Weindl <dweindl@users.noreply.github.com> * Update python/sdist/amici/jax/petab.py Co-authored-by: Daniel Weindl <dweindl@users.noreply.github.com> * Update python/sdist/amici/petab/petab_import.py Co-authored-by: Daniel Weindl <dweindl@users.noreply.github.com> * Apply suggestion from @dweindl Co-authored-by: Daniel Weindl <dweindl@users.noreply.github.com> * Apply suggestion from @dweindl Co-authored-by: Daniel Weindl <dweindl@users.noreply.github.com> * Apply suggestion from @dweindl Co-authored-by: Daniel Weindl <dweindl@users.noreply.github.com> * pre-commit fixes * fixup * update changelog * update testsuite, add support for cat * remove gitmodule * refactor testsuite * remove initialization tests * Update test_sciml.py --------- Co-authored-by: Branwen Snelling <branwen.snelling@crick.ac.uk> Co-authored-by: Daniel Weindl <dweindl@users.noreply.github.com>
1 parent 45f6650 commit 4360d5e

26 files changed

+2346
-59
lines changed
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
name: PEtab
2+
on:
3+
push:
4+
branches:
5+
- develop
6+
- main
7+
pull_request:
8+
branches:
9+
- main
10+
- develop
11+
- jax_sciml
12+
merge_group:
13+
workflow_dispatch:
14+
15+
jobs:
16+
build:
17+
name: PEtab SciML Testsuite
18+
19+
runs-on: ubuntu-latest
20+
21+
env:
22+
ENABLE_GCOV_COVERAGE: TRUE
23+
24+
strategy:
25+
matrix:
26+
python-version: ["3.12"]
27+
28+
steps:
29+
- name: Set up Python ${{ matrix.python-version }}
30+
uses: actions/setup-python@v5
31+
with:
32+
python-version: ${{ matrix.python-version }}
33+
34+
- uses: actions/checkout@v4
35+
with:
36+
fetch-depth: 20
37+
38+
# todo, update after https://github.com/sebapersson/petab_sciml_testsuite/issues/14 is merged
39+
- name: Download PEtab SciML test suite
40+
run: |
41+
git clone --depth 1 --branch main \
42+
https://github.com/FFroehlich/petab_sciml_testsuite \
43+
tests/sciml/testsuite
44+
45+
- name: Install apt dependencies
46+
uses: ./.github/actions/install-apt-dependencies
47+
48+
# install dependencies
49+
- name: apt
50+
run: |
51+
sudo apt-get update \
52+
&& sudo apt-get install -y python3-venv
53+
54+
- run: |
55+
echo "${HOME}/.local/bin/" >> $GITHUB_PATH
56+
57+
# install AMICI
58+
- name: Install python package
59+
run: scripts/installAmiciSource.sh
60+
61+
- name: Install petab
62+
run: |
63+
source ./venv/bin/activate \
64+
&& pip3 install wheel pytest shyaml pytest-cov
65+
66+
# retrieve test models
67+
- name: Download and install PEtab SciML
68+
run: |
69+
source ./venv/bin/activate \
70+
&& python -m pip install git+https://github.com/petab-dev/petab_sciml.git@main#subdirectory=src/python \
71+
72+
73+
- name: Install petab
74+
run: |
75+
source ./venv/bin/activate \
76+
&& python3 -m pip uninstall -y petab \
77+
&& python3 -m pip install git+https://github.com/petab-dev/libpetab-python.git@sciml \
78+
79+
- name: Run PEtab SciML testsuite
80+
run: |
81+
source ./venv/bin/activate \
82+
&& pytest --cov-report=xml:coverage_petab_sciml.xml \
83+
--cov=amici tests/sciml/test_sciml.py
84+
85+
- name: Codecov
86+
if: github.event_name == 'pull_request' || github.repository_owner == 'AMICI-dev'
87+
uses: codecov/codecov-action@v5
88+
with:
89+
token: ${{ secrets.CODECOV_TOKEN }}
90+
file: coverage_petab_sciml.xml
91+
flags: petab_sciml
92+
fail_ci_if_error: true

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ models/model_calvetti/build/*
3838

3939
amici_models/
4040

41+
# PEtab SciML test suite (downloaded dynamically)
42+
tests/sciml/testsuite/
43+
4144
simulate_model_*_hdf.m
4245
simulate_model_*.m
4346

@@ -196,3 +199,4 @@ debug/*
196199
tests/benchmark_models/cache_fiddy/*
197200
venv/*
198201
.coverage
202+
tests/sciml/models/*

.gitmodules

Whitespace-only changes.

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ See also our [versioning policy](https://amici.readthedocs.io/en/latest/versioni
7474
This only works on shared file systems, as the solver state is stored in a
7575
temporary HDF5 file.
7676
* `amici.ExpData` is now picklable.
77+
* Implemented support for the [PEtab SciML](https://github.com/PEtab-dev/petab_sciml)
78+
extension for the JAX interface.
7779
* The import function `sbml2amici`, `pysb2amici`, and `antimony2amici` now
7880
return an instance of the generated model class if called with `compile=True`
7981
(default).

doc/conf.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import exhale_multiproject_monkeypatch # noqa: F401
3232

3333
# need to import before setting typing.TYPE_CHECKING=True, fails otherwise
34+
3435
import amici
3536
import pandas as pd # noqa: F401
3637
import sympy as sp # noqa: F401
@@ -365,6 +366,11 @@ def install_doxygen():
365366
"ExpDataPtrVector": ":class:`amici.amici.ExpData`",
366367
}
367368

369+
# TODO: alias for forward type definition, remove after release of petab_sciml
370+
autodoc_type_aliases = {
371+
"NNModel": "petab_sciml.NNModel",
372+
}
373+
368374

369375
def process_docstring(app, what, name, obj, options, lines):
370376
# only apply in the amici.amici module

doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,13 @@
9393
"metadata": {},
9494
"outputs": [],
9595
"source": [
96-
"# Access the results\n",
97-
"results"
96+
"# Define the simulation condition\n",
97+
"simulation_condition = (\"model1_data1\",)\n",
98+
"\n",
99+
"# Access the results for the specified condition\n",
100+
"ic = results[\"simulation_conditions\"].index(simulation_condition)\n",
101+
"print(\"llh: \", results[\"llh\"][ic])\n",
102+
"print(\"state variables: \", results[\"x\"][ic, :])"
98103
]
99104
},
100105
{
@@ -356,7 +361,7 @@
356361
"metadata": {},
357362
"outputs": [],
358363
"source": [
359-
"grad._my"
364+
"grad._my[ic, :]"
360365
]
361366
},
362367
{
@@ -393,7 +398,7 @@
393398
"nps = jax_problem._np_numeric[ic, :]\n",
394399
"\n",
395400
"# Load parameters for the specified condition\n",
396-
"p = jax_problem.load_parameters(simulation_condition[0])\n",
401+
"p = jax_problem.load_model_parameters(simulation_condition[0])\n",
397402
"\n",
398403
"\n",
399404
"# Define a function to compute the gradient with respect to dynamic timepoints\n",
@@ -612,16 +617,16 @@
612617
]
613618
},
614619
{
615-
"cell_type": "code",
616-
"execution_count": null,
617-
"id": "b8382b0b2b68f49e",
618620
"metadata": {},
619-
"outputs": [],
621+
"cell_type": "code",
620622
"source": [
621623
"# Profile gradient computation using forward sensitivity analysis\n",
622624
"solver.set_sensitivity_order(amici.SensitivityOrder.first)\n",
623625
"solver.set_sensitivity_method(amici.SensitivityMethod.forward)"
624-
]
626+
],
627+
"id": "81fe95a6e7f613f1",
628+
"outputs": [],
629+
"execution_count": null
625630
},
626631
{
627632
"cell_type": "code",
@@ -687,8 +692,7 @@
687692
"mimetype": "text/x-python",
688693
"name": "python",
689694
"nbconvert_exporter": "python",
690-
"pygments_lexer": "ipython3",
691-
"version": "3.13.0"
695+
"pygments_lexer": "ipython3"
692696
}
693697
},
694698
"nbformat": 4,

doc/rtd_requirements.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ setuptools>=67.7.2
77
# https://github.com/pysb/pysb/pull/599
88
# for building the documentation, we don't care whether this fully works
99
git+https://github.com/pysb/pysb@0afeaab385e9a1d813ecf6fdaf0153f4b91358af
10+
# For forward type definition in generate_equinox
11+
git+https://github.com/PEtab-dev/petab_sciml.git@727d177fd3f85509d0bdcc278b672e9eeafd2384#subdirectory=src/python
1012
matplotlib>=3.7.1
1113
optax
1214
nbsphinx
@@ -16,6 +18,7 @@ sphinx_rtd_theme>=1.2.0
1618
petab[vis]>=0.2.0
1719
sphinx-autodoc-typehints
1820
ipython>=8.13.2
21+
h5py>=3.14.0
1922
breathe>=4.35.0
2023
exhale>=0.3.7
2124
-e git+https://github.com/mithro/sphinx-contrib-mithro#egg=sphinx-contrib-exhale-multiproject&subdirectory=sphinx-contrib-exhale-multiproject

python/sdist/amici/de_export.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ def __init__(
165165
allow_reinit_fixpar_initcond: bool | None = True,
166166
generate_sensitivity_code: bool | None = True,
167167
model_name: str | None = "model",
168+
hybridization: dict | None = None,
168169
):
169170
"""
170171
Generate AMICI C++ files for the DE provided to the constructor.
@@ -196,6 +197,10 @@ def __init__(
196197
197198
:param model_name:
198199
name of the model to be used during code generation
200+
201+
:param hybridization:
202+
dict representation of the hybridization information in the PEtab YAML file, see
203+
https://petab-sciml.readthedocs.io/latest/format.html#problem-yaml-file
199204
"""
200205
set_log_level(logger, verbose)
201206

@@ -237,6 +242,7 @@ def __init__(
237242
self.allow_reinit_fixpar_initcond: bool = allow_reinit_fixpar_initcond
238243
self._build_hints = set()
239244
self.generate_sensitivity_code: bool = generate_sensitivity_code
245+
self.hybridisation = hybridization
240246

241247
@log_execution_time("generating cpp code", logger)
242248
def generate_model_code(self) -> None:

0 commit comments

Comments
 (0)