@@ -107,34 +107,38 @@ def solution_to_xarray( # type: ignore
107107 solution = solution .view (self .state_dtype )[..., 0 ]
108108 params = self .extract_params (user_data )
109109
110- def as_dict (array , prepend = None ): # type: ignore
110+ def as_dict (array , dims , prepend = None ): # type: ignore
111111 if prepend is None :
112112 prepend = []
113113 dtype = array .dtype
114114 out = {}
115115 for name in dtype .names :
116116 if array [name ].dtype == np .float64 :
117- out ['_' .join (prepend + [name ])] = array [name ]
117+ out ['_' .join (prepend + [name ])] = ( tuple ( dims [ name ][ 1 ]), array [name ])
118118 else :
119- out .update (as_dict (array [name ], prepend + [name ]))
119+ out .update (as_dict (array [name ], dims [ name ], prepend + [name ]))
120120 return out
121121
122- data = xr .Dataset ()
122+ data = xr .Dataset (coords = self . coords )
123123 data ['time' ] = ('time' , tvals )
124124 # TODO t0?
125125 if unstack_state :
126- state = as_dict (solution , ['solution' ])
126+ state = as_dict (solution , self . state_subset . dims , ['solution' ])
127127 for name in state :
128- assert name not in data
129- data [name ] = ('time' , state [name ])
128+ if name in data :
129+ raise ValueError (f"Variable { name } is not unique." )
130+ dims , vals = state [name ]
131+ data [name ] = (('time' ,) + dims , vals )
130132 else :
131133 data ['solution' ] = ('time' , solution )
132134
133135 if unstack_params :
134- params = as_dict (params , ['parameters' ])
136+ params = as_dict (params , self . params_subset . dims , ['parameters' ])
135137 for name in params :
136- assert name not in data
137- data [name ] = params [name ]
138+ if name in data :
139+ raise ValueError (f"Variable { name } is not unique." )
140+ dims , vals = params [name ]
141+ data [name ] = (dims , vals )
138142 else :
139143 data ['parameters' ] = params
140144
0 commit comments