Skip to content

Commit c538116

Browse files
committed
Fix ReturnData.xr.*
Fix generating xarrays in `ReturnData.xr.*`. Broke during renaming 'parameters'. Increase test coverage.
1 parent 8fc33dc commit c538116

File tree

2 files changed

+32
-3
lines changed

2 files changed

+32
-3
lines changed

python/sdist/amici/numpy.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,15 @@ def __getattr__(self, name: str) -> xr.DataArray:
6666
if data is None:
6767
return xr.DataArray(name=name)
6868

69+
if isinstance(data, float):
70+
return xr.DataArray(data, name=name)
71+
72+
if not isinstance(data, np.ndarray):
73+
raise TypeError(
74+
f"Cannot create xarray DataArray for field {name} of type"
75+
f" {type(data)}"
76+
)
77+
6978
dims = None
7079

7180
match name:
@@ -91,7 +100,8 @@ def __getattr__(self, name: str) -> xr.DataArray:
91100
coords = {
92101
"time": self._svp.ts,
93102
"free_parameter": [
94-
self._svp.parameter_ids[i] for i in self._svp.plist
103+
self._svp.free_parameter_ids[i]
104+
for i in self._svp.plist
95105
],
96106
"observable": list(self._svp.observable_ids),
97107
}
@@ -112,11 +122,12 @@ def __getattr__(self, name: str) -> xr.DataArray:
112122
coords = {
113123
"time": self._svp.ts,
114124
"free_parameter": [
115-
self._svp.parameter_ids[i] for i in self._svp.plist
125+
self._svp.free_parameter_ids[i]
126+
for i in self._svp.plist
116127
],
117128
"state": list(self._svp.state_ids),
118129
}
119-
dims = ("time", "free_parameter", "state_variable")
130+
dims = ("time", "free_parameter", "state")
120131
case "sllh":
121132
coords = {
122133
"free_parameter": [
@@ -154,6 +165,15 @@ def __getattr__(self, name: str) -> xr.DataArray:
154165
)
155166
return arr
156167

168+
def __dir__(self):
169+
return sorted(
170+
set(
171+
itertools.chain(
172+
dir(super()), self.__dict__, self._svp._field_names
173+
)
174+
)
175+
)
176+
157177

158178
class SwigPtrView(collections.abc.Mapping):
159179
"""

python/tests/test_swig_interface.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,15 @@ def test_rdataview(sbml_example_presimulation_module):
573573
assert (xr_x.coords["time"].data == rdata.ts).all()
574574
assert (xr_x.coords["state"].data == model.get_state_ids()).all()
575575

576+
# test that generating the xarrays does not fail, without checking
577+
# their content any further
578+
for attr in dir(rdata.xr):
579+
if not attr.startswith("_"):
580+
try:
581+
getattr(rdata.xr, attr)
582+
except TypeError as e:
583+
print(str(e))
584+
576585

577586
def test_python_exceptions(sbml_example_presimulation_module):
578587
"""Test that C++ exceptions are correctly caught and re-raised in Python."""

0 commit comments

Comments
 (0)