Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions python/sdist/amici/sim/sundials/_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,15 @@ def __getattr__(self, name: str) -> xr.DataArray:
if data is None:
return xr.DataArray(name=name)

if isinstance(data, float):
return xr.DataArray(data, name=name)

if not isinstance(data, np.ndarray):
raise TypeError(
f"Cannot create xarray DataArray for field {name} of type"
f" {type(data)}"
)

dims = None

match name:
Expand All @@ -90,7 +99,8 @@ def __getattr__(self, name: str) -> xr.DataArray:
coords = {
"time": self._svp.ts,
"free_parameter": [
self._svp.parameter_ids[i] for i in self._svp.plist
self._svp.free_parameter_ids[i]
for i in self._svp.plist
],
"observable": list(self._svp.observable_ids),
}
Expand All @@ -111,11 +121,12 @@ def __getattr__(self, name: str) -> xr.DataArray:
coords = {
"time": self._svp.ts,
"free_parameter": [
self._svp.parameter_ids[i] for i in self._svp.plist
self._svp.free_parameter_ids[i]
for i in self._svp.plist
],
"state": list(self._svp.state_ids),
}
dims = ("time", "free_parameter", "state_variable")
dims = ("time", "free_parameter", "state")
case "sllh":
coords = {
"free_parameter": [
Expand Down Expand Up @@ -153,6 +164,15 @@ def __getattr__(self, name: str) -> xr.DataArray:
)
return arr

def __dir__(self):
return sorted(
set(
itertools.chain(
dir(super()), self.__dict__, self._svp._field_names
)
)
)


class SwigPtrView(collections.abc.Mapping):
"""
Expand Down
9 changes: 9 additions & 0 deletions python/tests/test_swig_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,15 @@ def test_rdataview(sbml_example_presimulation_module):
assert (xr_x.coords["time"].data == rdata.ts).all()
assert (xr_x.coords["state"].data == model.get_state_ids()).all()

# test that generating the xarrays does not fail, without checking
# their content any further
for attr in dir(rdata.xr):
if not attr.startswith("_"):
try:
getattr(rdata.xr, attr)
except TypeError as e:
print(str(e))


def test_python_exceptions(sbml_example_presimulation_module):
"""Test that C++ exceptions are correctly caught and re-raised in Python."""
Expand Down
Loading