Skip to content

Commit d1bdebd

Browse files
authored
Add Model.simulate (#2963)
Add `Model.simulate()` as a convenience function to run simulations without having to create a `Solver` object explicitly. This is a wrapper for both `amici.run_simulation` and `amici.run_simulations`, depending on the type of the `edata` argument. It also supports passing some `Solver` options as keyword arguments.
1 parent 9704acd commit d1bdebd

File tree

7 files changed

+277
-66
lines changed

7 files changed

+277
-66
lines changed

CHANGELOG.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ See also our [versioning policy](https://amici.readthedocs.io/en/latest/versioni
3434
* For a more consistent API, all function names are now snake_case instead of
3535
camelCase.
3636
* `Model.getSolver` has been renamed to `Model.create_solver`.
37+
* `amici.runAmiciSimulation` and `amici.runAmiciSimulations` have been renamed
38+
to `amici.run_simulation` and `amici.run_simulations`.
3739
* The following deprecated functionality has been removed:
3840
* The complete MATLAB interface has been removed.
3941
* `NonlinearSolverIteration::functional` has been removed,
@@ -54,7 +56,11 @@ See also our [versioning policy](https://amici.readthedocs.io/en/latest/versioni
5456
`DataArray`s include the identifiers and are often more convenient than the
5557
plain numpy arrays. This allows for easy subselection and plotting of the
5658
results, and conversion to DataFrames.
57-
59+
* `Model.simulate()` has been added as a convenience function to run
60+
simulations without having to create a `Solver` object explicitly.
61+
This is a wrapper for both `amici.run_simulation` and
62+
`amici.run_simulations`, depending on the type of the `edata` argument.
63+
It also supports passing some `Solver` options as keyword arguments.
5864

5965
## v0.X Series
6066

doc/examples/getting_started/GettingStarted.ipynb

Lines changed: 38 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@
2121
},
2222
{
2323
"cell_type": "code",
24-
"execution_count": 1,
2524
"metadata": {},
26-
"outputs": [],
2725
"source": [
2826
"import amici\n",
2927
"\n",
3028
"sbml_importer = amici.SbmlImporter(\"model_steadystate_scaled.xml\")"
31-
]
29+
],
30+
"outputs": [],
31+
"execution_count": null
3232
},
3333
{
3434
"cell_type": "markdown",
@@ -39,14 +39,14 @@
3939
},
4040
{
4141
"cell_type": "code",
42-
"execution_count": 2,
4342
"metadata": {},
44-
"outputs": [],
4543
"source": [
4644
"model_name = \"model_steadystate\"\n",
4745
"model_dir = \"model_dir\"\n",
4846
"sbml_importer.sbml2amici(model_name, model_dir)"
49-
]
47+
],
48+
"outputs": [],
49+
"execution_count": null
5050
},
5151
{
5252
"cell_type": "markdown",
@@ -58,17 +58,17 @@
5858
},
5959
{
6060
"cell_type": "code",
61-
"execution_count": 3,
6261
"metadata": {},
63-
"outputs": [],
6462
"source": [
6563
"# load the model module\n",
6664
"model_module = amici.import_model_module(model_name, model_dir)\n",
6765
"# instantiate model\n",
6866
"model = model_module.get_model()\n",
6967
"# instantiate solver\n",
7068
"solver = model.create_solver()"
71-
]
69+
],
70+
"outputs": [],
71+
"execution_count": null
7272
},
7373
{
7474
"cell_type": "markdown",
@@ -77,10 +77,10 @@
7777
},
7878
{
7979
"cell_type": "code",
80-
"execution_count": 4,
8180
"metadata": {},
81+
"source": "model.set_parameter_by_name(\"p1\", 1e-3)",
8282
"outputs": [],
83-
"source": "model.set_parameter_by_name(\"p1\", 1e-3)"
83+
"execution_count": null
8484
},
8585
{
8686
"cell_type": "markdown",
@@ -89,10 +89,10 @@
8989
},
9090
{
9191
"cell_type": "code",
92-
"execution_count": 5,
9392
"metadata": {},
93+
"source": "solver.set_absolute_tolerance(1e-10)",
9494
"outputs": [],
95-
"source": "solver.set_absolute_tolerance(1e-10)"
95+
"execution_count": null
9696
},
9797
{
9898
"cell_type": "markdown",
@@ -104,18 +104,18 @@
104104
{
105105
"cell_type": "markdown",
106106
"metadata": {},
107-
"source": "Model simulations can be executed using the [amici.run_simulation](https://amici.readthedocs.io/en/latest/generated/amici.html#amici.run_simulation) routine. By default, the model does not contain any timepoints for which the model is to be simulated. Here we define a simulation timecourse with two timepoints at `0` and `1` and then run the simulation."
107+
"source": "Model simulations can be executed using the [Model.simulate](https://amici.readthedocs.io/en/latest/generated/amici.amici.html#amici.amici.Model.simulate) method (or, alternatively, [amici.run_simulation](https://amici.readthedocs.io/en/latest/generated/amici.html#amici.run_simulation)). By default, the model does not contain any output timepoints for which the model is to be simulated. Here we define a simulation timecourse with two output timepoints at `0` and `1` and then run the simulation."
108108
},
109109
{
110110
"cell_type": "code",
111-
"execution_count": 6,
112111
"metadata": {},
113-
"outputs": [],
114112
"source": [
115113
"# set timepoints\n",
116114
"model.set_timepoints([0, 1])\n",
117-
"rdata = amici.run_simulation(model, solver)"
118-
]
115+
"rdata = model.simulate(solver=solver)"
116+
],
117+
"outputs": [],
118+
"execution_count": null
119119
},
120120
{
121121
"cell_type": "markdown",
@@ -126,24 +126,12 @@
126126
},
127127
{
128128
"cell_type": "code",
129-
"execution_count": 7,
130-
"metadata": {},
131-
"outputs": [
132-
{
133-
"data": {
134-
"text/plain": [
135-
"array([[0.1 , 0.4 , 0.7 ],\n",
136-
" [0.98208413, 0.51167992, 0.10633388]])"
137-
]
138-
},
139-
"execution_count": 7,
140-
"metadata": {},
141-
"output_type": "execute_result"
142-
}
143-
],
129+
"metadata": {},
144130
"source": [
145131
"rdata.x"
146-
]
132+
],
133+
"outputs": [],
134+
"execution_count": null
147135
},
148136
{
149137
"cell_type": "markdown",
@@ -152,21 +140,22 @@
152140
},
153141
{
154142
"cell_type": "code",
155-
"execution_count": 8,
156-
"metadata": {},
157-
"outputs": [
158-
{
159-
"data": {
160-
"text/plain": [
161-
"('x1', 'x2', 'x3')"
162-
]
163-
},
164-
"execution_count": 8,
165-
"metadata": {},
166-
"output_type": "execute_result"
167-
}
168-
],
169-
"source": "model.get_state_names()"
143+
"metadata": {},
144+
"source": "model.get_state_names()",
145+
"outputs": [],
146+
"execution_count": null
147+
},
148+
{
149+
"metadata": {},
150+
"cell_type": "markdown",
151+
"source": "For convenience, most results stored in `ReturnData` can also be retrieved as [xarray.DataArray](https://docs.xarray.dev/en/stable/index.html) objects that already include the respective row and column names. This can be accessed via the `xr` attribute of `ReturnData`. Here, we access the model state `x` as `DataArray` object to convert it to a `pandas.DataFrame`:"
152+
},
153+
{
154+
"metadata": {},
155+
"cell_type": "code",
156+
"source": "rdata.xr.x.to_pandas()",
157+
"outputs": [],
158+
"execution_count": null
170159
},
171160
{
172161
"cell_type": "markdown",

python/sdist/amici/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def get_model(self) -> amici.Model:
150150
"""Create a model instance."""
151151
...
152152

153-
AmiciModel = Union[amici.Model, amici.ModelPtr]
153+
AmiciModel = amici.Model | amici.ModelPtr
154154
else:
155155
ModelModule = ModuleType
156156

python/sdist/amici/swig_wrappers.py

Lines changed: 74 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
"""Convenience wrappers for the swig interface"""
22

3+
from __future__ import annotations
34
import logging
45
import warnings
56
from typing import Any
7+
from collections.abc import Sequence
8+
import contextlib
69

710
import amici
811
import amici.amici as amici_swig
@@ -12,8 +15,11 @@
1215
AmiciExpDataVector,
1316
AmiciModel,
1417
AmiciSolver,
18+
SensitivityMethod,
19+
SensitivityOrder,
20+
Solver,
1521
)
16-
from . import numpy
22+
from . import numpy, ReturnDataView
1723
from .logging import get_logger
1824

1925
logger = get_logger(__name__, log_level=logging.DEBUG)
@@ -33,7 +39,7 @@ def run_simulation(
3339
model: AmiciModel,
3440
solver: AmiciSolver,
3541
edata: AmiciExpData | None = None,
36-
) -> "numpy.ReturnDataView":
42+
) -> ReturnDataView:
3743
"""
3844
Convenience wrapper around :py:func:`amici.amici.run_simulation`
3945
(generated by swig)
@@ -79,7 +85,7 @@ def run_simulations(
7985
edata_list: AmiciExpDataVector,
8086
failfast: bool = True,
8187
num_threads: int = 1,
82-
) -> list["numpy.ReturnDataView"]:
88+
) -> list[ReturnDataView]:
8389
"""
8490
Convenience wrapper for loops of amici.runAmiciSimulation
8591
@@ -267,5 +273,70 @@ def _ids_and_names_to_rdata(
267273
f"{entity_type.lower()}_{name_or_id.lower()}",
268274
names_or_ids,
269275
)
276+
270277
rdata.state_ids_solver = model.get_state_ids_solver()
271278
rdata.state_names_solver = model.get_state_names_solver()
279+
280+
281+
@contextlib.contextmanager
282+
def _solver_settings(solver, sensi_method=None, sensi_order=None):
283+
"""Context manager to temporarily apply solver settings."""
284+
old_method = old_order = None
285+
286+
if sensi_method is not None:
287+
old_method = solver.get_sensitivity_method()
288+
if isinstance(sensi_method, str):
289+
sensi_method = SensitivityMethod[sensi_method]
290+
solver.set_sensitivity_method(sensi_method)
291+
292+
if sensi_order is not None:
293+
old_order = solver.get_sensitivity_order()
294+
if isinstance(sensi_order, str):
295+
sensi_order = SensitivityOrder[sensi_order]
296+
solver.set_sensitivity_order(sensi_order)
297+
298+
try:
299+
yield solver
300+
finally:
301+
if old_method is not None:
302+
solver.set_sensitivity_method(old_method)
303+
if old_order is not None:
304+
solver.set_sensitivity_order(old_order)
305+
306+
307+
def _Model__simulate(
308+
self: AmiciModel,
309+
*,
310+
solver: Solver | None = None,
311+
edata: AmiciExpData | AmiciExpDataVector | None = None,
312+
failfast: bool = True,
313+
num_threads: int = 1,
314+
sensi_method: SensitivityMethod | str = None,
315+
sensi_order: SensitivityOrder | str = None,
316+
) -> ReturnDataView | list[ReturnDataView]:
317+
"""
318+
For use in `swig/model.i` to avoid code duplication in subclasses.
319+
320+
Keep in sync with `Model.simulate` and `ModelPtr.simulate`.
321+
322+
"""
323+
if solver is None:
324+
solver = self.create_solver()
325+
326+
with _solver_settings(
327+
solver=solver, sensi_method=sensi_method, sensi_order=sensi_order
328+
):
329+
if isinstance(edata, Sequence):
330+
return run_simulations(
331+
model=_get_ptr(self),
332+
solver=_get_ptr(solver),
333+
edata_list=edata,
334+
failfast=failfast,
335+
num_threads=num_threads,
336+
)
337+
338+
return run_simulation(
339+
model=_get_ptr(self),
340+
solver=_get_ptr(solver),
341+
edata=_get_ptr(edata),
342+
)

python/tests/test_sbml_import.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,7 @@ def test_nosensi(tempdir):
118118

119119
model = model_module.get_model()
120120
model.set_timepoints(np.linspace(0, 60, 61))
121-
solver = model.create_solver()
122-
solver.set_sensitivity_order(amici.SensitivityOrder.first)
123-
solver.set_sensitivity_method(amici.SensitivityMethod.forward)
124-
rdata = amici.run_simulation(model, solver)
121+
rdata = model.simulate(sensi_order="first", sensi_method="forward")
125122
assert rdata.status == amici.AMICI_ERROR
126123

127124

@@ -598,15 +595,15 @@ def test_likelihoods(model_test_likelihoods):
598595

599596
# run model once to create an edata
600597

601-
rdata = amici.run_simulation(model, solver)
598+
rdata = model.simulate(solver=solver)
602599
sigmas = rdata["y"].max(axis=0) * 0.05
603600
edata = amici.ExpData(rdata, sigmas, [])
604601
# just make all observables positive since some are logarithmic
605602
while min(edata.get_observed_data()) < 0:
606603
edata = amici.ExpData(rdata, sigmas, [])
607604

608605
# and now run for real and also compute likelihood values
609-
rdata = amici.run_simulations(model, solver, [edata])[0]
606+
rdata = model.simulate(solver=solver, edata=[edata])[0]
610607

611608
# check if the values make overall sense
612609
assert np.isfinite(rdata["llh"])
@@ -1054,8 +1051,7 @@ def test_regression_2700(tempdir):
10541051
model_module = import_model_module(model_name, tempdir)
10551052
model = model_module.get_model()
10561053
model.set_timepoints([0, 1, 2])
1057-
solver = model.create_solver()
1058-
rdata = amici.run_simulation(model, solver)
1054+
rdata = model.simulate()
10591055

10601056
assert np.all(rdata.by_id("pp") == [1, 1, 1])
10611057

@@ -1093,8 +1089,7 @@ def test_heaviside_init_values_and_bool_to_float_conversion(tempdir):
10931089

10941090
model = model_module.get_model()
10951091
model.set_timepoints([0, 1, 2])
1096-
solver = model.create_solver()
1097-
rdata = amici.run_simulation(model, solver)
1092+
rdata = model.simulate()
10981093

10991094
assert np.all(rdata.by_id("a") == np.array([2, 2, 2])), rdata.by_id("a")
11001095
assert np.all(rdata.by_id("b") == np.array([0, 1, 1])), rdata.by_id("b")

swig/amici.i

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -373,11 +373,12 @@ if sys.platform == 'win32':
373373
// import additional types for typehints
374374
// also import np for use in __repr__ functions
375375
%pythonbegin %{
376-
from typing import TYPE_CHECKING, Iterable, Union
376+
from typing import TYPE_CHECKING, Iterable, Union, overload
377377
from collections.abc import Sequence
378378
import numpy as np
379379
if TYPE_CHECKING:
380380
import numpy
381+
from .numpy import ReturnDataView
381382
%}
382383

383384
%pythoncode %{
@@ -418,7 +419,7 @@ __all__ = [
418419
x
419420
for x in dir(sys.modules[__name__])
420421
if not x.startswith('_')
421-
and x not in {"np", "sys", "os", "numpy", "IntEnum", "enum", "pi", "TYPE_CHECKING", "Iterable", "Sequence", "Path"}
422+
and x not in {"np", "sys", "os", "numpy", "IntEnum", "enum", "pi", "TYPE_CHECKING", "Iterable", "Sequence", "Path", "Union", "overload"}
422423
]
423424

424425
%}

0 commit comments

Comments
 (0)