diff --git a/python/sdist/amici/sim/sundials/_numpy.py b/python/sdist/amici/sim/sundials/_numpy.py index 3e254db175..50a9c5de92 100644 --- a/python/sdist/amici/sim/sundials/_numpy.py +++ b/python/sdist/amici/sim/sundials/_numpy.py @@ -65,6 +65,15 @@ def __getattr__(self, name: str) -> xr.DataArray: if data is None: return xr.DataArray(name=name) + if isinstance(data, float): + return xr.DataArray(data, name=name) + + if not isinstance(data, np.ndarray): + raise TypeError( + f"Cannot create xarray DataArray for field {name} of type" + f" {type(data)}" + ) + dims = None match name: @@ -90,7 +99,8 @@ def __getattr__(self, name: str) -> xr.DataArray: coords = { "time": self._svp.ts, "free_parameter": [ - self._svp.parameter_ids[i] for i in self._svp.plist + self._svp.free_parameter_ids[i] + for i in self._svp.plist ], "observable": list(self._svp.observable_ids), } @@ -111,11 +121,12 @@ def __getattr__(self, name: str) -> xr.DataArray: coords = { "time": self._svp.ts, "free_parameter": [ - self._svp.parameter_ids[i] for i in self._svp.plist + self._svp.free_parameter_ids[i] + for i in self._svp.plist ], "state": list(self._svp.state_ids), } - dims = ("time", "free_parameter", "state_variable") + dims = ("time", "free_parameter", "state") case "sllh": coords = { "free_parameter": [ @@ -153,6 +164,15 @@ def __getattr__(self, name: str) -> xr.DataArray: ) return arr + def __dir__(self): + return sorted( + set( + itertools.chain( + dir(super()), self.__dict__, self._svp._field_names + ) + ) + ) + class SwigPtrView(collections.abc.Mapping): """ diff --git a/python/tests/test_swig_interface.py b/python/tests/test_swig_interface.py index c2823228ab..370f82d9fa 100644 --- a/python/tests/test_swig_interface.py +++ b/python/tests/test_swig_interface.py @@ -594,6 +594,15 @@ def test_rdataview(sbml_example_presimulation_module): assert (xr_x.coords["time"].data == rdata.ts).all() assert (xr_x.coords["state"].data == model.get_state_ids()).all() + # test that generating the xarrays does not fail, without checking + # their content any further + for attr in dir(rdata.xr): + if not attr.startswith("_"): + try: + getattr(rdata.xr, attr) + except TypeError as e: + print(str(e)) + def test_python_exceptions(sbml_example_presimulation_module): """Test that C++ exceptions are correctly caught and re-raised in Python."""