99
1010from .meta import ChainMeta , RunMeta , Variable
1111from .npproto .utils import ndarray_to_numpy
12+ from .utils import as_array_from_ragged
1213
1314InferenceData = TypeVar ("InferenceData" )
1415try :
@@ -252,7 +253,15 @@ def to_inferencedata(self, *, equalize_chain_lengths: bool = True, **kwargs) ->
252253 warmup_sample_stats [svar .name ].append (stats [tune ])
253254 sample_stats [svar .name ].append (stats [~ tune ])
254255
255- kwargs .setdefault ("save_warmup" , True )
256+ if not equalize_chain_lengths :
257+ # Convert ragged arrays to object-dtyped ndarray because NumPy >=1.24.0 no longer does that automatically
258+ warmup_posterior = {k : as_array_from_ragged (v ) for k , v in warmup_posterior .items ()}
259+ warmup_sample_stats = {
260+ k : as_array_from_ragged (v ) for k , v in warmup_sample_stats .items ()
261+ }
262+ posterior = {k : as_array_from_ragged (v ) for k , v in posterior .items ()}
263+ sample_stats = {k : as_array_from_ragged (v ) for k , v in sample_stats .items ()}
264+
256265 idata = from_dict (
257266 warmup_posterior = warmup_posterior ,
258267 warmup_sample_stats = warmup_sample_stats ,
@@ -263,6 +272,7 @@ def to_inferencedata(self, *, equalize_chain_lengths: bool = True, **kwargs) ->
263272 attrs = self .meta .attributes ,
264273 constant_data = self .constant_data ,
265274 observed_data = self .observed_data ,
275+ save_warmup = True ,
266276 ** kwargs ,
267277 )
268278 return idata
0 commit comments