Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 13 additions & 14 deletions doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@
"\n",
"To begin, we will import a model using [PEtab](https://petab.readthedocs.io). For this demonstration, we will utilize the [Benchmark Collection](https://github.com/Benchmarking-Initiative/Benchmark-Models-PEtab), which provides a diverse set of models. For more information on importing PEtab models, refer to the corresponding [PEtab notebook](https://amici.readthedocs.io/en/latest/petab.html).\n",
"\n",
"In this tutorial, we will import the Böhm model from the Benchmark Collection. Using [amici.petab_import](https://amici.readthedocs.io/en/latest/generated/amici.petab_import.html#amici.petab_import.import_petab_problem), we will load the PEtab problem. To create a [JAXModel](https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.JAXModel) instead of a standard AMICI model, we set the `jax` parameter to `True`.\n"
"In this tutorial, we will import the Böhm model from the Benchmark Collection. Using [amici.petab_import](https://amici.readthedocs.io/en/latest/generated/amici.petab_import.html#amici.petab_import.import_petab_problem), we will load the PEtab problem. To create a [JAXProblem](https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.JAXProblem) instead of a standard AMICI model, we set the `jax` parameter to `True`.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c71c96da0da3144a",
"id": "e5ade24ee5aca07c",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -44,35 +44,32 @@
"# Load the PEtab problem from the YAML file\n",
"petab_problem = petab.Problem.from_yaml(yaml_url)\n",
"\n",
"# Import the PEtab problem as a JAX-compatible AMICI model\n",
"jax_model = import_petab_problem(\n",
"# Import the PEtab problem as a JAX-compatible AMICI problem\n",
"jax_problem = import_petab_problem(\n",
" petab_problem,\n",
" verbose=False, # no text output\n",
" jax=True, # return jax model\n",
" jax=True, # return jax problem\n",
")"
]
},
{
"cell_type": "markdown",
"id": "7e0f1c27bd71ee1f",
"id": "5d0dce1427a7883f",
"metadata": {},
"source": [
"## Simulation\n",
"\n",
"In principle, we can already use this model for simulation using the [simulate_condition](https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.JAXModel.simulate_condition) method. However, this approach can be cumbersome as timepoints, data etc. need to be specified manually. Instead, we process the PEtab problem into a [JAXProblem](https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.JAXProblem), which enables efficient simulation using [amici.jax.run_simulations]((https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.run_simulations)."
"We can now run efficient simulation using [amici.jax.run_simulations]((https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.run_simulations)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ccecc9a29acc7b73",
"id": "4a99eccd0607793e",
"metadata": {},
"outputs": [],
"source": [
"from amici.jax import JAXProblem, run_simulations\n",
"\n",
"# Create a JAXProblem from the JAX model and PEtab problem\n",
"jax_problem = JAXProblem(jax_model, petab_problem)\n",
"from amici.jax import run_simulations\n",
"\n",
"# Run simulations and compute the log-likelihood\n",
"llh, results = run_simulations(jax_problem)"
Expand Down Expand Up @@ -141,7 +138,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "72f1ed397105e14a",
"id": "78f06ff19626df7d",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -168,7 +165,9 @@
" for ix in range(results[\"x\"].shape[2]):\n",
" time_points = np.array(results[\"ts\"][ic, :])\n",
" state_values = np.array(results[\"x\"][ic, :, ix])\n",
" plt.plot(time_points, state_values, label=jax_model.state_ids[ix])\n",
" plt.plot(\n",
" time_points, state_values, label=jax_problem.model.state_ids[ix]\n",
" )\n",
"\n",
" # Add labels, legend, and grid\n",
" plt.xlabel(\"Time\")\n",
Expand Down
16 changes: 11 additions & 5 deletions python/sdist/amici/petab/petab_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def import_petab_problem(
non_estimated_parameters_as_constants=True,
jax=False,
**kwargs,
) -> "amici.Model | amici.JAXModel":
) -> "amici.Model | amici.jax.JAXProblem":
"""
Create an AMICI model for a PEtab problem.

Expand All @@ -73,16 +73,17 @@ def import_petab_problem(
parameters are required, this should be set to ``False``.

:param jax:
Whether to load the jax version of the model. Note that this disables
compilation of the model module unless `compile` is set to `True`.
Whether to create a JAX-based problem. If ``True``, returns a
:class:`amici.jax.JAXProblem` instance. If ``False``, returns a
standard AMICI model.

:param kwargs:
Additional keyword arguments to be passed to
:meth:`amici.sbml_import.SbmlImporter.sbml2amici` or
:func:`amici.pysb_import.pysb2amici`, depending on the model type.

:return:
The imported model.
The imported model (if ``jax=False``) or JAX problem (if ``jax=True``).
"""
if petab_problem.model.type_id not in (MODEL_TYPE_SBML, MODEL_TYPE_PYSB):
raise NotImplementedError(
Expand Down Expand Up @@ -263,13 +264,18 @@ def import_petab_problem(
)

if jax:
from amici.jax import JAXProblem

model = model_module.Model()

logger.info(
f"Successfully loaded jax model {model_name} "
f"from {model_output_dir}."
)
return model

# Create and return JAXProblem
logger.info(f"Successfully created JAXProblem for {model_name}.")
return JAXProblem(model, petab_problem)

model = model_module.get_model()
check_model(amici_model=model, petab_problem=petab_problem)
Expand Down
9 changes: 3 additions & 6 deletions python/tests/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,20 +276,18 @@ def test_preequilibration_failure(lotka_volterra): # noqa: F811
petab_problem = lotka_volterra
# oscillating system, preequilibation should fail when interaction is active
with TemporaryDirectoryWinSafe(prefix="normal") as model_dir:
jax_model = import_petab_problem(
jax_problem = import_petab_problem(
petab_problem, jax=True, model_output_dir=model_dir
)
jax_problem = JAXProblem(jax_model, petab_problem)
r = run_simulations(jax_problem)
assert not np.isinf(r[0].item())
petab_problem.measurement_df[PREEQUILIBRATION_CONDITION_ID] = (
petab_problem.measurement_df[SIMULATION_CONDITION_ID]
)
with TemporaryDirectoryWinSafe(prefix="failure") as model_dir:
jax_model = import_petab_problem(
jax_problem = import_petab_problem(
petab_problem, jax=True, model_output_dir=model_dir
)
jax_problem = JAXProblem(jax_model, petab_problem)
r = run_simulations(jax_problem)
assert np.isinf(r[0].item())

Expand All @@ -300,10 +298,9 @@ def test_serialisation(lotka_volterra): # noqa: F811
with TemporaryDirectoryWinSafe(
prefix=petab_problem.model.model_id
) as model_dir:
jax_model = import_petab_problem(
jax_problem = import_petab_problem(
petab_problem, jax=True, model_output_dir=model_dir
)
jax_problem = JAXProblem(jax_model, petab_problem)
# change parameters to random values to test serialisation
jax_problem.update_parameters(
jax_problem.parameters
Expand Down
5 changes: 2 additions & 3 deletions tests/benchmark_models/test_petab_benchmark_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import jax.numpy as jnp
import numpy as np
import pytest
from amici.jax.petab import JAXProblem, run_simulations
from amici.jax.petab import run_simulations
from amici.petab.petab_import import import_petab_problem
from amici.petab.simulations import LLH, SLLH, simulate_petab
from beartype import beartype
Expand Down Expand Up @@ -83,12 +83,11 @@ def test_jax_llh(benchmark_problem):
r_amici = simulate_amici()
llh_amici = r_amici[LLH]

jax_model = import_petab_problem(
jax_problem = import_petab_problem(
petab_problem,
model_output_dir=benchmark_outdir / (problem_id + "_jax"),
jax=True,
)
jax_problem = JAXProblem(jax_model, petab_problem)
if problem_parameters:
jax_problem = eqx.tree_at(
lambda x: x.parameters,
Expand Down
9 changes: 6 additions & 3 deletions tests/petab_test_suite/test_petab_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,18 +64,20 @@ def _test_case(case, model_type, version, jax):
f"petab_{model_type}_test_case_{case}_{version.replace('.', '_')}"
)
model_output_dir = f"amici_models/{model_name}" + ("_jax" if jax else "")
model = import_petab_problem(
imported = import_petab_problem(
petab_problem=problem,
model_output_dir=model_output_dir,
model_name=model_name,
compile_=True,
jax=jax,
)
if jax:
from amici.jax import JAXProblem, petab_simulate, run_simulations
from amici.jax import petab_simulate, run_simulations

steady_state_event = diffrax.steady_state_event(rtol=1e-6, atol=1e-6)
jax_problem = JAXProblem(model, problem)
jax_problem = (
imported # import_petab_problem returns JAXProblem when jax=True
)
llh, ret = run_simulations(
jax_problem, steady_state_event=steady_state_event
)
Expand All @@ -89,6 +91,7 @@ def _test_case(case, model_type, version, jax):
columns={petab.SIMULATION: petab.MEASUREMENT}, inplace=True
)
else:
model = imported # import_petab_problem returns Model when jax=False
solver = model.create_solver()
solver.set_steady_state_tolerance_factor(1.0)
problem_parameters = dict(
Expand Down
5 changes: 1 addition & 4 deletions tests/sciml/test_sciml.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import petab.v1 as petab
import pytest
from amici.jax import (
JAXProblem,
generate_equinox,
petab_simulate,
run_simulations,
Expand Down Expand Up @@ -197,15 +196,13 @@ def test_ude(test):

petab_yaml["format_version"] = "2.0.0" # TODO: fixme
petab_problem = Problem.from_yaml(petab_yaml)
jax_model = import_petab_problem(
jax_problem = import_petab_problem(
petab_problem,
model_output_dir=Path(__file__).parent / "models" / test,
compile_=True,
jax=True,
)

jax_problem = JAXProblem(jax_model, petab_problem)

# llh
llh, r = run_simulations(jax_problem)
np.testing.assert_allclose(
Expand Down
Loading