Skip to content

Commit d3fe385

Browse files
committed
Fix issue in solution to xarray conversion
1 parent f3c4082 commit d3fe385

File tree

1 file changed

+14
-10
lines changed

1 file changed

+14
-10
lines changed

sunode/problem.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -107,34 +107,38 @@ def solution_to_xarray( # type: ignore
107107
solution = solution.view(self.state_dtype)[..., 0]
108108
params = self.extract_params(user_data)
109109

110-
def as_dict(array, prepend=None): # type: ignore
110+
def as_dict(array, dims, prepend=None): # type: ignore
111111
if prepend is None:
112112
prepend = []
113113
dtype = array.dtype
114114
out = {}
115115
for name in dtype.names:
116116
if array[name].dtype == np.float64:
117-
out['_'.join(prepend + [name])] = array[name]
117+
out['_'.join(prepend + [name])] = (tuple(dims[name][1]), array[name])
118118
else:
119-
out.update(as_dict(array[name], prepend + [name]))
119+
out.update(as_dict(array[name], dims[name], prepend + [name]))
120120
return out
121121

122-
data = xr.Dataset()
122+
data = xr.Dataset(coords=self.coords)
123123
data['time'] = ('time', tvals)
124124
# TODO t0?
125125
if unstack_state:
126-
state = as_dict(solution, ['solution'])
126+
state = as_dict(solution, self.state_subset.dims, ['solution'])
127127
for name in state:
128-
assert name not in data
129-
data[name] = ('time', state[name])
128+
if name in data:
129+
raise ValueError(f"Variable {name} is not unique.")
130+
dims, vals = state[name]
131+
data[name] = (('time',) + dims, vals)
130132
else:
131133
data['solution'] = ('time', solution)
132134

133135
if unstack_params:
134-
params = as_dict(params, ['parameters'])
136+
params = as_dict(params, self.params_subset.dims, ['parameters'])
135137
for name in params:
136-
assert name not in data
137-
data[name] = params[name]
138+
if name in data:
139+
raise ValueError(f"Variable {name} is not unique.")
140+
dims, vals = params[name]
141+
data[name] = (dims, vals)
138142
else:
139143
data['parameters'] = params
140144

0 commit comments

Comments
 (0)