Skip to content

Commit d5d77c4

Browse files
authored
ReturnData fields as xarray.DataArray (#2916)
Make relevant `ReturnData` fields available as `xarray.DataArray`. This includes the identifiers and is often more convenient than the plain arrays, allows for easy subselection and plotting of the results, and conversion to DataFrames. Closes #2170.
1 parent 5e0acb2 commit d5d77c4

File tree

5 files changed

+168
-5
lines changed

5 files changed

+168
-5
lines changed

CHANGELOG.md

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ See also our [versioning policy](https://amici.readthedocs.io/en/latest/versioni
66

77
### v1.0.0 (unreleased)
88

9-
BREAKING CHANGES
9+
**BREAKING CHANGES**
1010

1111
* `ReturnDataView.posteq_numsteps` and `ReturnDataView.posteq_numsteps` now
1212
return a one-dimensional array of shape `(num_timepoints,)` instead of a
@@ -29,6 +29,15 @@ BREAKING CHANGES
2929
* The `force_compile` argument to `import_petab_problem` has been removed.
3030
See the `compile_` argument.
3131

32+
**Features**
33+
34+
* Many relevant `ReturnData` fields are now available as `xarray.DataArray`
35+
via `ReturnData.xr.{x,y,w,x0,sx,...}`.
36+
`DataArray`s include the identifiers and are often more convenient than the
37+
plain numpy arrays. This allows for easy subselection and plotting of the
38+
results, and conversion to DataFrames.
39+
40+
3241
## v0.X Series
3342

3443
### v0.34.1 (2025-08-25)

doc/examples/getting_started_extended/GettingStartedExtended.ipynb

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -877,6 +877,25 @@
877877
"print(f\"{rdata.by_id('x2')=}\")"
878878
]
879879
},
880+
{
881+
"metadata": {},
882+
"cell_type": "markdown",
883+
"source": "Alternatively, those data can be accessed through `ReturnData.xr.*` as [xarray.DataArray](https://docs.xarray.dev/en/stable/index.html) objects, that contain additional metadata such as timepoints and identifiers. This allows for more convenient indexing and plotting of the results."
884+
},
885+
{
886+
"metadata": {},
887+
"cell_type": "code",
888+
"source": "rdata.xr.x",
889+
"outputs": [],
890+
"execution_count": null
891+
},
892+
{
893+
"metadata": {},
894+
"cell_type": "code",
895+
"source": "rdata.xr.x.to_pandas()",
896+
"outputs": [],
897+
"execution_count": null
898+
},
880899
{
881900
"cell_type": "markdown",
882901
"metadata": {},

python/sdist/amici/numpy.py

Lines changed: 126 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
This module provides views on C++ objects for efficient access.
55
"""
66

7+
from __future__ import annotations
78
import collections
89
import copy
910
import itertools
10-
from typing import Literal, Union
11+
from typing import Literal
1112
from collections.abc import Iterator
1213
from numbers import Number
1314
import amici
@@ -22,10 +23,131 @@
2223
ReturnDataPtr,
2324
SteadyStateStatus,
2425
)
26+
import xarray as xr
27+
28+
29+
__all__ = [
30+
"ReturnDataView",
31+
"ExpDataView",
32+
"evaluate",
33+
]
2534

2635
StrOrExpr = str | sp.Expr
2736

2837

38+
class XArrayFactory:
39+
"""
40+
Factory class to create xarray DataArrays for fields of a
41+
SwigPtrView instance.
42+
43+
Currently, only ReturnDataView is supported.
44+
"""
45+
46+
def __init__(self, svp: SwigPtrView):
47+
"""
48+
Constructor
49+
50+
:param svp: SwigPtrView instance to create DataArrays from.
51+
"""
52+
self._svp = svp
53+
54+
def __getattr__(self, name: str) -> xr.DataArray:
55+
"""
56+
Create xarray DataArray for field name
57+
58+
:param name: field name
59+
60+
:returns: xarray DataArray
61+
"""
62+
data = getattr(self._svp, name)
63+
if data is None:
64+
return xr.DataArray(name=name)
65+
66+
dims = None
67+
68+
match name:
69+
case "x":
70+
coords = {
71+
"time": self._svp.ts,
72+
"state": list(self._svp.state_ids),
73+
}
74+
case "x0" | "x_ss":
75+
coords = {
76+
"state": list(self._svp.state_ids),
77+
}
78+
case "xdot":
79+
coords = {
80+
"state": list(self._svp.state_ids_solver),
81+
}
82+
case "y" | "sigmay":
83+
coords = {
84+
"time": self._svp.ts,
85+
"observable": list(self._svp.observable_ids),
86+
}
87+
case "sy" | "ssigmay":
88+
coords = {
89+
"time": self._svp.ts,
90+
"parameter": [
91+
self._svp.parameter_ids[i] for i in self._svp.plist
92+
],
93+
"observable": list(self._svp.observable_ids),
94+
}
95+
case "w":
96+
coords = {
97+
"time": self._svp.ts,
98+
"expression": list(self._svp.expression_ids),
99+
}
100+
case "sx0":
101+
coords = {
102+
"parameter": [
103+
self._svp.parameter_ids[i] for i in self._svp.plist
104+
],
105+
"state": list(self._svp.state_ids),
106+
}
107+
case "sx":
108+
coords = {
109+
"time": self._svp.ts,
110+
"parameter": [
111+
self._svp.parameter_ids[i] for i in self._svp.plist
112+
],
113+
"state": list(self._svp.state_ids),
114+
}
115+
dims = ("time", "parameter", "state")
116+
case "sllh":
117+
coords = {
118+
"parameter": [
119+
self._svp.parameter_ids[i] for i in self._svp.plist
120+
]
121+
}
122+
case "FIM":
123+
coords = {
124+
"parameter1": [
125+
self._svp.parameter_ids[i] for i in self._svp.plist
126+
],
127+
"parameter2": [
128+
self._svp.parameter_ids[i] for i in self._svp.plist
129+
],
130+
}
131+
case "J":
132+
coords = {
133+
"state1": list(self._svp.state_ids_solver),
134+
"state2": list(self._svp.state_ids_solver),
135+
}
136+
case _:
137+
dims = tuple(f"dim_{i}" for i in range(data.ndim))
138+
coords = {
139+
f"dim_{i}": np.arange(dim)
140+
for i, dim in enumerate(data.shape)
141+
}
142+
arr = xr.DataArray(
143+
data,
144+
dims=dims,
145+
coords=coords,
146+
name=name,
147+
)
148+
return arr
149+
150+
29151
class SwigPtrView(collections.abc.Mapping):
30152
"""
31153
Interface class to expose ``std::vector<double>`` and scalar members of
@@ -104,6 +226,7 @@ def __init__(self, swigptr):
104226
"""
105227
self._swigptr = swigptr
106228
self._cache = {}
229+
107230
super().__init__()
108231

109232
def __len__(self) -> int:
@@ -310,6 +433,7 @@ def __init__(self, rdata: ReturnDataPtr | ReturnData):
310433
"numerrtestfailsB": [rdata.nt],
311434
"numnonlinsolvconvfailsB": [rdata.nt],
312435
}
436+
self.xr = XArrayFactory(self)
313437
super().__init__(rdata)
314438

315439
def __getitem__(
@@ -461,7 +585,7 @@ def _field_as_numpy(
461585

462586
def _entity_type_from_id(
463587
entity_id: str,
464-
rdata: Union[amici.ReturnData, "amici.ReturnDataView"] = None,
588+
rdata: amici.ReturnData | amici.ReturnDataView = None,
465589
model: amici.Model = None,
466590
) -> Literal["x", "y", "w", "p", "k"]:
467591
"""Guess the type of some entity by its ID."""

python/sdist/pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ dependencies = [
2929
"toposort",
3030
"setuptools>=48",
3131
"mpmath",
32-
"swig"
32+
"swig",
33+
"xarray>=2025.01.0",
3334
]
3435
license = "BSD-3-Clause"
3536
authors = [

python/tests/test_swig_interface.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import numbers
88
from math import nan
99
import pytest
10+
import xarray
1011

1112
import amici
1213
import numpy as np
@@ -531,6 +532,7 @@ def test_rdataview(sbml_example_presimulation_module):
531532
"""Test some SwigPtrView functionality via ReturnDataView."""
532533
model_module = sbml_example_presimulation_module
533534
model = model_module.getModel()
535+
model.setTimepoints([1, 2, 3])
534536
rdata = amici.runAmiciSimulation(model, model.getSolver())
535537
assert isinstance(rdata, amici.ReturnDataView)
536538

@@ -547,11 +549,19 @@ def test_rdataview(sbml_example_presimulation_module):
547549

548550
assert not hasattr(rdata, "nonexisting_attribute")
549551
assert "x" in rdata
550-
assert rdata.x == rdata["x"]
552+
assert (rdata.x == rdata["x"]).all()
551553

552554
# field names are included by dir()
553555
assert "x" in dir(rdata)
554556

557+
# Test xarray conversion
558+
xr_x = rdata.xr.x
559+
assert isinstance(xr_x, xarray.DataArray)
560+
assert (rdata.x == xr_x).all()
561+
assert xr_x.dims == ("time", "state")
562+
assert (xr_x.coords["time"].data == rdata.ts).all()
563+
assert (xr_x.coords["state"].data == model.getStateIds()).all()
564+
555565

556566
def test_python_exceptions(sbml_example_presimulation_module):
557567
"""Test that C++ exceptions are correctly caught and re-raised in Python."""

0 commit comments

Comments
 (0)