
Commit 4eebf16

Split jax/cpp PEtab benchmark tests (#2794)
* Remove import guard from PEtab benchmark
* Apply suggestions from code review
* remove spurious inits
* move and update AGENTS.md
* Move benchmark fixture to conftest
* Update tests/benchmark-models/conftest.py
* Update tests/benchmark-models/test_petab_benchmark_jax.py
1 parent dafab72 commit 4eebf16

File tree

6 files changed: +306 -153 lines

.github/workflows/test_benchmark_collection_models.yml

Lines changed: 71 additions & 7 deletions
@@ -14,8 +14,8 @@ on:
     - cron: '48 4 * * *'

 jobs:
-  build:
-    name: Benchmark Collection
+  cpp:
+    name: Benchmark Collection CPP

     runs-on: ubuntu-24.04

@@ -66,11 +66,11 @@ jobs:
         env:
           AMICI_PARALLEL_COMPILE: ""
         run: |
-          cd tests/benchmark-models && pytest \
-            --durations=10 \
-            --cov=amici \
-            --cov-report=xml:"coverage_py.xml" \
-            --cov-append \
+          cd tests/benchmark-models && pytest test_petab_benchmark.py \
+            --durations=10 \
+            --cov=amici \
+            --cov-report=xml:"coverage_py.xml" \
+            --cov-append \

       - name: Codecov Python
         if: github.event_name == 'pull_request' || github.repository_owner == 'AMICI-dev'
@@ -93,3 +93,67 @@ jobs:
           path: |
             tests/benchmark-models/computation_times.csv
             tests/benchmark-models/computation_times.png
+
+  jax:
+    name: Benchmark Collection JAX
+    runs-on: ubuntu-24.04
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: [ "3.12" ]
+        extract_subexpressions: ["true", "false"]
+    env:
+      AMICI_EXTRACT_CSE: ${{ matrix.extract_subexpressions }}
+    steps:
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 20
+
+      - name: Install apt dependencies
+        uses: ./.github/actions/install-apt-dependencies
+
+      - run: echo "${HOME}/.local/bin/" >> $GITHUB_PATH
+
+      - name: Create AMICI sdist
+        run: pip3 install build && cd python/sdist && python3 -m build --sdist
+
+      - name: Install AMICI sdist
+        run: |
+          pip3 install --user petab[vis] && \
+          pip3 install -v --user \
+            $(ls -t python/sdist/dist/amici-*.tar.gz | head -1)[petab,test,vis,jax]
+
+      - name: Install test dependencies
+        run: |
+          python3 -m pip uninstall -y petab && python3 -m pip install git+https://github.com/petab-dev/libpetab-python.git@develop \
+            && python3 -m pip install -U sympy \
+            && python3 -m pip install git+https://github.com/ICB-DCM/fiddy.git
+
+      - name: Download benchmark collection
+        run: |
+          pip install git+https://github.com/Benchmarking-Initiative/Benchmark-Models-PEtab.git@master#subdirectory=src/python
+
+      - name: Run JAX tests
+        env:
+          AMICI_PARALLEL_COMPILE: ""
+        run: |
+          cd tests/benchmark-models && pytest test_petab_benchmark_jax.py \
+            --durations=10 \
+            --cov=amici \
+            --cov-report=xml:"coverage_py.xml" \
+            --cov-append
+
+      - name: Codecov Python
+        if: github.event_name == 'pull_request' || github.repository_owner == 'AMICI-dev'
+        uses: codecov/codecov-action@v5
+        with:
+          token: ${{ secrets.CODECOV_TOKEN }}
+          files: coverage_py.xml
+          flags: python
+          fail_ci_if_error: true
+          verbose: true
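The new `jax` job runs a two-way matrix over `extract_subexpressions` and exports each value as the `AMICI_EXTRACT_CSE` environment variable. A minimal sketch of how such a string-valued CI flag is typically consumed on the Python side; the parsing below is illustrative only, the diff does not show how AMICI itself reads this variable:

```python
import os

def extract_cse_enabled() -> bool:
    # Matrix values arrive as the strings "true"/"false"; normalize
    # before comparing, and default to disabled when the variable is unset.
    return os.environ.get("AMICI_EXTRACT_CSE", "false").strip().lower() in ("1", "true")

os.environ["AMICI_EXTRACT_CSE"] = "true"
print(extract_cse_enabled())  # → True
```

Keeping the flag string-typed in the workflow (`"true"`/`"false"` rather than YAML booleans) avoids surprises when GitHub Actions coerces matrix values to strings anyway.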

doc/AGENTS.md

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+# Project Agents.md Guide for OpenAI Codex
+
+This Agents.md file provides comprehensive guidance for OpenAI Codex and other AI agents working with this codebase.
+
+## Project Structure for OpenAI Codex Navigation
+
+AMICI is a Python package that uses SWIG to generate Python bindings to C++ code. There is also a deprecated MATLAB interface to the C++ code, which will be removed at some point in the future.
+
+- `/binder`: binder configuration
+- `/cmake`: various cmake utility functions
+- `/container`: docker configuration
+- `/doc`: high-level documentation; all API documentation is automatically generated using doxygen/sphinx
+- `/include`: C++ header files
+- `/matlab`: MATLAB interface
+- `/models`: pre-generated C++ models for testing
+- `/python`: Python source code
+  - `/benchmark`: helper scripts for benchmarking
+  - `/sdist`: Python package
+    - `/amici`: Python module
+      - `/_codegen`: helper functions for C++ code generation
+      - `/debugging`: helper functions for debugging simulation failures
+      - `/include`: symlink to C++ headers
+      - `/jax`: code for the JAX backend, an alternative to the C++ backend
+      - `/petab`: interface for simulating models/problems in the PEtab format
+      - `/testing`: helper functions for tests
+  - `/tests`: self-contained Python-package-specific tests
+- `/scripts`: bash scripts for testing and installing the package
+- `/src`: C++ source files
+- `/swig`: definition of the SWIG interface
+- `/tests`: C++ tests and Python tests that require third-party resources; the base directory contains tests for the SBML test suite
+  - `/benchmark-models`: regression tests for the PEtab benchmark collection
+  - `/cpp`: C++ tests
+  - `/generateTestConfig`: helper functions to generate Python test configurations for pre-generated C++ models
+  - `/performance`: performance tests for the PEtab benchmark collection
+  - `/petab_test_suite`: regression tests for the PEtab test suite
+
+## Contribution Guidelines
+
+Please see the instructions in `doc/CONTRIBUTING.md`.
+
+# Agent Instructions
+
+To ensure all tools and dependencies are available, activate the virtual environment before running any commands:
+
+```bash
+source ./venv/bin/activate
+```
+
+This project uses `pre-commit` for linting and `pytest` for tests. Run them on changed files whenever you make modifications.
+
+When running the tests locally, change into the test directory first:
+
+```bash
+cd tests/benchmark-models
+pytest test_petab_benchmark.py
+pytest test_petab_benchmark_jax.py
+```
+
+To quickly verify the benchmark tests, you can limit execution to a small model:
+
+```bash
+pytest -k Boehm_JProteomeRes2014 test_petab_benchmark.py
+```

doc/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,7 @@ def install_doxygen():
259259
"MATLAB_.md",
260260
"CPP_.md",
261261
"gfx",
262+
"AGENTS.md",
262263
]
263264

264265
# The name of the Pygments (syntax highlighting) style to use.
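The one-line change above keeps the new `doc/AGENTS.md` out of the rendered documentation: Sphinx skips any source file matched by an entry in `exclude_patterns`, which are glob-style paths relative to the source directory. A rough illustration of that matching using Python's stdlib `fnmatch` (Sphinx's actual matcher differs in details such as `**` handling):

```python
from fnmatch import fnmatch

# Patterns from the diff above; glob-style, relative to the doc source dir.
exclude_patterns = ["MATLAB_.md", "CPP_.md", "gfx", "AGENTS.md"]

def is_excluded(relpath: str) -> bool:
    # A file is skipped if any exclude pattern matches its relative path.
    return any(fnmatch(relpath, pat) for pat in exclude_patterns)

print(is_excluded("AGENTS.md"))  # → True
print(is_excluded("README.md"))  # → False
```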

tests/benchmark-models/conftest.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+import copy
+from pathlib import Path
+
+import pytest
+import petab.v1 as petab
+from petab.v1.lint import measurement_table_has_timepoint_specific_mappings
+
+import benchmark_models_petab
+from amici.petab.petab_import import import_petab_problem
+
+from test_petab_benchmark import problems
+
+script_dir = Path(__file__).parent
+repo_root = script_dir.parent.parent
+benchmark_outdir = repo_root / "test_bmc"
+
+
+@pytest.fixture(scope="session", params=problems, ids=problems)
+def benchmark_problem(request):
+    """Fixture providing model and PEtab problem for a benchmark model."""
+    problem_id = request.param
+    petab_problem = benchmark_models_petab.get_problem(problem_id)
+    flat_petab_problem = copy.deepcopy(petab_problem)
+    if measurement_table_has_timepoint_specific_mappings(
+        petab_problem.measurement_df,
+    ):
+        petab.flatten_timepoint_specific_output_overrides(flat_petab_problem)
+
+    amici_model = import_petab_problem(
+        flat_petab_problem,
+        model_output_dir=benchmark_outdir / problem_id,
+    )
+    return problem_id, flat_petab_problem, petab_problem, amici_model
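Moving this fixture into `conftest.py` is what lets both `test_petab_benchmark.py` and `test_petab_benchmark_jax.py` share it: it is session-scoped and parametrized over all problem IDs, so each benchmark model is imported at most once per test session. A stripped-down sketch of the same pattern; `PROBLEM_IDS` and `expensive_import` are placeholders, not the real benchmark collection:

```python
import pytest

# Placeholder subset of benchmark problem IDs.
PROBLEM_IDS = ["Boehm_JProteomeRes2014", "Fujita_SciSignal2010"]

def expensive_import(problem_id: str) -> dict:
    # Stand-in for import_petab_problem(): imagine slow model compilation here.
    return {"problem_id": problem_id}

@pytest.fixture(scope="session", params=PROBLEM_IDS, ids=PROBLEM_IDS)
def benchmark_problem(request):
    # scope="session": each parametrization is set up once per session and
    # shared by every test that requests the fixture.
    # ids=PROBLEM_IDS puts the problem name into the test ID, which is what
    # makes selection like `pytest -k Boehm_JProteomeRes2014` work.
    problem_id = request.param
    return problem_id, expensive_import(problem_id)
```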

tests/benchmark-models/test_petab_benchmark.py

Lines changed: 17 additions & 146 deletions
@@ -5,40 +5,38 @@
 for a subset of the benchmark problems.
 """

-import copy
-from functools import partial
 from pathlib import Path

+import contextlib
+import logging
 import os
-import fiddy
-import amici
+from collections import defaultdict
+from dataclasses import dataclass, field
+
 import numpy as np
 import pandas as pd
 import petab.v1 as petab
 import pytest
-from amici.petab.petab_import import import_petab_problem
-import benchmark_models_petab
-from collections import defaultdict
-from dataclasses import dataclass, field
-from amici import SensitivityMethod
-from petab.v1.lint import measurement_table_has_timepoint_specific_mappings
-from fiddy import MethodId, get_derivative
-from fiddy.derivative_check import NumpyIsCloseDerivativeCheck
-from fiddy.extensions.amici import simulate_petab_to_cached_functions
-from fiddy.success import Consistency
-import contextlib
-import logging
 import yaml
+from petab.v1.lint import measurement_table_has_timepoint_specific_mappings
+from petab.v1.visualize import plot_problem
+
+import amici
+from amici import SensitivityMethod
 from amici.logging import get_logger
+from amici.petab.petab_import import import_petab_problem
 from amici.petab.simulations import (
     LLH,
-    SLLH,
     RDATAS,
     rdatas_to_measurement_df,
     simulate_petab,
 )
-
-from petab.v1.visualize import plot_problem
+import benchmark_models_petab
+import fiddy
+from fiddy import MethodId, get_derivative
+from fiddy.derivative_check import NumpyIsCloseDerivativeCheck
+from fiddy.extensions.amici import simulate_petab_to_cached_functions
+from fiddy.success import Consistency


 # Enable various debug output
@@ -241,133 +239,6 @@ class GradientCheckSettings:
 )


-@pytest.fixture(scope="session", params=problems, ids=problems)
-def benchmark_problem(request):
-    """Fixture providing model and PEtab problem for a problem from
-    the benchmark problem collection."""
-    problem_id = request.param
-    petab_problem = benchmark_models_petab.get_problem(problem_id)
-    flat_petab_problem = copy.deepcopy(petab_problem)
-    if measurement_table_has_timepoint_specific_mappings(
-        petab_problem.measurement_df,
-    ):
-        petab.flatten_timepoint_specific_output_overrides(flat_petab_problem)
-
-    # Setup AMICI objects.
-    amici_model = import_petab_problem(
-        flat_petab_problem,
-        model_output_dir=benchmark_outdir / problem_id,
-    )
-    return problem_id, flat_petab_problem, petab_problem, amici_model
-
-
-@pytest.mark.filterwarnings(
-    "ignore:The following problem parameters were not used *",
-    "ignore: The environment variable *",
-    "ignore:Adjoint sensitivity analysis for models with discontinuous ",
-)
-def test_jax_llh(benchmark_problem):
-    import jax
-    import equinox as eqx
-    import jax.numpy as jnp
-    from amici.jax.petab import run_simulations, JAXProblem
-
-    jax.config.update("jax_enable_x64", True)
-    from beartype import beartype
-
-    problem_id, flat_petab_problem, petab_problem, amici_model = (
-        benchmark_problem
-    )
-
-    amici_solver = amici_model.getSolver()
-    cur_settings = settings[problem_id]
-    amici_solver.setAbsoluteTolerance(1e-8)
-    amici_solver.setRelativeTolerance(1e-8)
-    amici_solver.setMaxSteps(10_000)
-
-    simulate_amici = partial(
-        simulate_petab,
-        petab_problem=flat_petab_problem,
-        amici_model=amici_model,
-        solver=amici_solver,
-        scaled_parameters=True,
-        scaled_gradients=True,
-        log_level=logging.DEBUG,
-    )
-
-    np.random.seed(cur_settings.rng_seed)
-
-    problem_parameters = None
-    if problem_id in problems_for_gradient_check:
-        point = flat_petab_problem.x_nominal_free_scaled
-        for _ in range(20):
-            amici_solver.setSensitivityMethod(amici.SensitivityMethod.adjoint)
-            amici_solver.setSensitivityOrder(amici.SensitivityOrder.first)
-            amici_model.setSteadyStateSensitivityMode(
-                cur_settings.ss_sensitivity_mode
-            )
-            point_noise = (
-                np.random.randn(len(point)) * cur_settings.noise_level
-            )
-            point += point_noise  # avoid small gradients at nominal value
-
-            problem_parameters = dict(
-                zip(flat_petab_problem.x_free_ids, point)
-            )
-
-            r_amici = simulate_amici(
-                problem_parameters=problem_parameters,
-            )
-            if np.isfinite(r_amici[LLH]):
-                break
-        else:
-            raise RuntimeError("Could not compute expected derivative.")
-    else:
-        r_amici = simulate_amici()
-    llh_amici = r_amici[LLH]
-
-    jax_model = import_petab_problem(
-        petab_problem,
-        model_output_dir=benchmark_outdir / (problem_id + "_jax"),
-        jax=True,
-    )
-    jax_problem = JAXProblem(jax_model, petab_problem)
-    if problem_parameters:
-        jax_problem = eqx.tree_at(
-            lambda x: x.parameters,
-            jax_problem,
-            jnp.array(
-                [problem_parameters[pid] for pid in jax_problem.parameter_ids]
-            ),
-        )
-
-    if problem_id in problems_for_gradient_check:
-        beartype(run_simulations)(jax_problem)
-        (llh_jax, _), sllh_jax = eqx.filter_value_and_grad(
-            run_simulations, has_aux=True
-        )(jax_problem)
-    else:
-        llh_jax, _ = beartype(run_simulations)(jax_problem)
-
-    np.testing.assert_allclose(
-        llh_jax,
-        llh_amici,
-        rtol=1e-3,
-        atol=1e-3,
-        err_msg=f"LLH mismatch for {problem_id}",
-    )
-
-    if problem_id in problems_for_gradient_check:
-        sllh_amici = r_amici[SLLH]
-        np.testing.assert_allclose(
-            sllh_jax.parameters,
-            np.array([sllh_amici[pid] for pid in jax_problem.parameter_ids]),
-            rtol=1e-2,
-            atol=1e-2,
-            err_msg=f"SLLH mismatch for {problem_id}, {dict(zip(jax_problem.parameter_ids, sllh_jax.parameters))}",
-        )
-
-
 @pytest.mark.filterwarnings(
     "ignore:divide by zero encountered in log",
     # https://github.com/AMICI-dev/AMICI/issues/18
0 commit comments