diff --git a/doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb b/doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb index bfd26757d4..4a2cc76ec4 100644 --- a/doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb +++ b/doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb @@ -21,13 +21,13 @@ "\n", "To begin, we will import a model using [PEtab](https://petab.readthedocs.io). For this demonstration, we will utilize the [Benchmark Collection](https://github.com/Benchmarking-Initiative/Benchmark-Models-PEtab), which provides a diverse set of models. For more information on importing PEtab models, refer to the corresponding [PEtab notebook](https://amici.readthedocs.io/en/latest/petab.html).\n", "\n", - "In this tutorial, we will import the Böhm model from the Benchmark Collection. Using [amici.petab_import](https://amici.readthedocs.io/en/latest/generated/amici.petab_import.html#amici.petab_import.import_petab_problem), we will load the PEtab problem. To create a [JAXModel](https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.JAXModel) instead of a standard AMICI model, we set the `jax` parameter to `True`.\n" + "In this tutorial, we will import the Böhm model from the Benchmark Collection. Using [amici.petab_import](https://amici.readthedocs.io/en/latest/generated/amici.petab_import.html#amici.petab_import.import_petab_problem), we will load the PEtab problem. To create a [JAXProblem](https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.JAXProblem) instead of a standard AMICI model, we set the `jax` parameter to `True`.\n" ] }, { "cell_type": "code", "execution_count": null, - "id": "c71c96da0da3144a", + "id": "e5ade24ee5aca07c", "metadata": {}, "outputs": [], "source": [ @@ -44,35 +44,32 @@ "# Load the PEtab problem from the YAML file\n", "petab_problem = petab.Problem.from_yaml(yaml_url)\n", "\n", - "# Import the PEtab problem as a JAX-compatible AMICI model\n", - "jax_model = import_petab_problem(\n", + "# Import the PEtab problem as a JAX-compatible AMICI problem\n", + "jax_problem = import_petab_problem(\n", " petab_problem,\n", " verbose=False, # no text output\n", - " jax=True, # return jax model\n", + " jax=True, # return jax problem\n", ")" ] }, { "cell_type": "markdown", - "id": "7e0f1c27bd71ee1f", + "id": "5d0dce1427a7883f", "metadata": {}, "source": [ "## Simulation\n", "\n", - "In principle, we can already use this model for simulation using the [simulate_condition](https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.JAXModel.simulate_condition) method. However, this approach can be cumbersome as timepoints, data etc. need to be specified manually. Instead, we process the PEtab problem into a [JAXProblem](https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.JAXProblem), which enables efficient simulation using [amici.jax.run_simulations]((https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.run_simulations)." + "We can now run efficient simulation using [amici.jax.run_simulations]((https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.run_simulations)." ] }, { "cell_type": "code", "execution_count": null, - "id": "ccecc9a29acc7b73", + "id": "4a99eccd0607793e", "metadata": {}, "outputs": [], "source": [ - "from amici.jax import JAXProblem, run_simulations\n", - "\n", - "# Create a JAXProblem from the JAX model and PEtab problem\n", - "jax_problem = JAXProblem(jax_model, petab_problem)\n", + "from amici.jax import run_simulations\n", "\n", "# Run simulations and compute the log-likelihood\n", "llh, results = run_simulations(jax_problem)" @@ -141,7 +138,7 @@ { "cell_type": "code", "execution_count": null, - "id": "72f1ed397105e14a", + "id": "78f06ff19626df7d", "metadata": {}, "outputs": [], "source": [ @@ -168,7 +165,9 @@ " for ix in range(results[\"x\"].shape[2]):\n", " time_points = np.array(results[\"ts\"][ic, :])\n", " state_values = np.array(results[\"x\"][ic, :, ix])\n", - " plt.plot(time_points, state_values, label=jax_model.state_ids[ix])\n", + " plt.plot(\n", + " time_points, state_values, label=jax_problem.model.state_ids[ix]\n", + " )\n", "\n", " # Add labels, legend, and grid\n", " plt.xlabel(\"Time\")\n", diff --git a/python/sdist/amici/petab/petab_import.py b/python/sdist/amici/petab/petab_import.py index 32cefb0845..ac49894685 100644 --- a/python/sdist/amici/petab/petab_import.py +++ b/python/sdist/amici/petab/petab_import.py @@ -46,7 +46,7 @@ def import_petab_problem( non_estimated_parameters_as_constants=True, jax=False, **kwargs, -) -> "amici.Model | amici.JAXModel": +) -> "amici.Model | amici.jax.JAXProblem": """ Create an AMICI model for a PEtab problem. @@ -73,8 +73,9 @@ def import_petab_problem( parameters are required, this should be set to ``False``. :param jax: - Whether to load the jax version of the model. Note that this disables - compilation of the model module unless `compile` is set to `True`. + Whether to create a JAX-based problem. If ``True``, returns a + :class:`amici.jax.JAXProblem` instance. If ``False``, returns a + standard AMICI model. :param kwargs: Additional keyword arguments to be passed to @@ -82,7 +83,7 @@ def import_petab_problem( :func:`amici.pysb_import.pysb2amici`, depending on the model type. :return: - The imported model. + The imported model (if ``jax=False``) or JAX problem (if ``jax=True``). """ if petab_problem.model.type_id not in (MODEL_TYPE_SBML, MODEL_TYPE_PYSB): raise NotImplementedError( @@ -263,13 +264,18 @@ def import_petab_problem( ) if jax: + from amici.jax import JAXProblem + model = model_module.Model() logger.info( f"Successfully loaded jax model {model_name} " f"from {model_output_dir}." ) - return model + + # Create and return JAXProblem + logger.info(f"Successfully created JAXProblem for {model_name}.") + return JAXProblem(model, petab_problem) model = model_module.get_model() check_model(amici_model=model, petab_problem=petab_problem) diff --git a/python/tests/test_jax.py b/python/tests/test_jax.py index 055d0c9a61..37073f6701 100644 --- a/python/tests/test_jax.py +++ b/python/tests/test_jax.py @@ -276,20 +276,18 @@ def test_preequilibration_failure(lotka_volterra): # noqa: F811 petab_problem = lotka_volterra # oscillating system, preequilibation should fail when interaction is active with TemporaryDirectoryWinSafe(prefix="normal") as model_dir: - jax_model = import_petab_problem( + jax_problem = import_petab_problem( petab_problem, jax=True, model_output_dir=model_dir ) - jax_problem = JAXProblem(jax_model, petab_problem) r = run_simulations(jax_problem) assert not np.isinf(r[0].item()) petab_problem.measurement_df[PREEQUILIBRATION_CONDITION_ID] = ( petab_problem.measurement_df[SIMULATION_CONDITION_ID] ) with TemporaryDirectoryWinSafe(prefix="failure") as model_dir: - jax_model = import_petab_problem( + jax_problem = import_petab_problem( petab_problem, jax=True, model_output_dir=model_dir ) - jax_problem = JAXProblem(jax_model, petab_problem) r = run_simulations(jax_problem) assert np.isinf(r[0].item()) @@ -300,10 +298,9 @@ def test_serialisation(lotka_volterra): # noqa: F811 with TemporaryDirectoryWinSafe( prefix=petab_problem.model.model_id ) as model_dir: - jax_model = import_petab_problem( + jax_problem = import_petab_problem( petab_problem, jax=True, model_output_dir=model_dir ) - jax_problem = JAXProblem(jax_model, petab_problem) # change parameters to random values to test serialisation jax_problem.update_parameters( jax_problem.parameters diff --git a/tests/benchmark_models/test_petab_benchmark_jax.py b/tests/benchmark_models/test_petab_benchmark_jax.py index 831c9849a1..ef0c4b0680 100644 --- a/tests/benchmark_models/test_petab_benchmark_jax.py +++ b/tests/benchmark_models/test_petab_benchmark_jax.py @@ -7,7 +7,7 @@ import jax.numpy as jnp import numpy as np import pytest -from amici.jax.petab import JAXProblem, run_simulations +from amici.jax.petab import run_simulations from amici.petab.petab_import import import_petab_problem from amici.petab.simulations import LLH, SLLH, simulate_petab from beartype import beartype @@ -83,12 +83,11 @@ def test_jax_llh(benchmark_problem): r_amici = simulate_amici() llh_amici = r_amici[LLH] - jax_model = import_petab_problem( + jax_problem = import_petab_problem( petab_problem, model_output_dir=benchmark_outdir / (problem_id + "_jax"), jax=True, ) - jax_problem = JAXProblem(jax_model, petab_problem) if problem_parameters: jax_problem = eqx.tree_at( lambda x: x.parameters, diff --git a/tests/petab_test_suite/test_petab_suite.py b/tests/petab_test_suite/test_petab_suite.py index 43c503c6d4..fae48528f8 100755 --- a/tests/petab_test_suite/test_petab_suite.py +++ b/tests/petab_test_suite/test_petab_suite.py @@ -64,7 +64,7 @@ def _test_case(case, model_type, version, jax): f"petab_{model_type}_test_case_{case}_{version.replace('.', '_')}" ) model_output_dir = f"amici_models/{model_name}" + ("_jax" if jax else "") - model = import_petab_problem( + imported = import_petab_problem( petab_problem=problem, model_output_dir=model_output_dir, model_name=model_name, @@ -72,10 +72,12 @@ def _test_case(case, model_type, version, jax): jax=jax, ) if jax: - from amici.jax import JAXProblem, petab_simulate, run_simulations + from amici.jax import petab_simulate, run_simulations steady_state_event = diffrax.steady_state_event(rtol=1e-6, atol=1e-6) - jax_problem = JAXProblem(model, problem) + jax_problem = ( + imported # import_petab_problem returns JAXProblem when jax=True + ) llh, ret = run_simulations( jax_problem, steady_state_event=steady_state_event ) @@ -89,6 +91,7 @@ def _test_case(case, model_type, version, jax): columns={petab.SIMULATION: petab.MEASUREMENT}, inplace=True ) else: + model = imported # import_petab_problem returns Model when jax=False solver = model.create_solver() solver.set_steady_state_tolerance_factor(1.0) problem_parameters = dict( diff --git a/tests/sciml/test_sciml.py b/tests/sciml/test_sciml.py index ae541b0d36..76961f556d 100644 --- a/tests/sciml/test_sciml.py +++ b/tests/sciml/test_sciml.py @@ -14,7 +14,6 @@ import petab.v1 as petab import pytest from amici.jax import ( - JAXProblem, generate_equinox, petab_simulate, run_simulations, @@ -197,15 +196,13 @@ def test_ude(test): petab_yaml["format_version"] = "2.0.0" # TODO: fixme petab_problem = Problem.from_yaml(petab_yaml) - jax_model = import_petab_problem( + jax_problem = import_petab_problem( petab_problem, model_output_dir=Path(__file__).parent / "models" / test, compile_=True, jax=True, ) - jax_problem = JAXProblem(jax_model, petab_problem) - # llh llh, r = run_simulations(jax_problem) np.testing.assert_allclose(