@@ -381,14 +381,21 @@ def save(self, fname: str) -> None:
381
381
raise RuntimeError ("The model hasn't been fit yet, call .fit() first" )
382
382
383
383
@classmethod
384
- def _convert_dims_to_tuple (cls , model_config : Dict ) -> Dict :
384
+ def _model_config_formatting (cls , model_config : Dict ) -> Dict :
385
+ """
386
+ Because of json serialization, model_config values that were originally tuples or numpy are being encoded as lists.
387
+ This function converts them back to tuples and numpy arrays to ensure correct id encoding.
388
+ """
385
389
for key in model_config :
386
- if (
387
- isinstance (model_config [key ], dict )
388
- and "dims" in model_config [key ]
389
- and isinstance (model_config [key ]["dims" ], list )
390
- ):
391
- model_config [key ]["dims" ] = tuple (model_config [key ]["dims" ])
390
+ if isinstance (model_config [key ], dict ):
391
+ for sub_key in model_config [key ]:
392
+ if isinstance (model_config [key ][sub_key ], list ):
393
+ # Check if "dims" key to convert it to tuple
394
+ if sub_key == "dims" :
395
+ model_config [key ][sub_key ] = tuple (model_config [key ][sub_key ])
396
+ # Convert all other lists to numpy arrays
397
+ else :
398
+ model_config [key ][sub_key ] = np .array (model_config [key ][sub_key ])
392
399
return model_config
393
400
394
401
@classmethod
@@ -420,7 +427,7 @@ def load(cls, fname: str):
420
427
filepath = Path (str (fname ))
421
428
idata = az .from_netcdf (filepath )
422
429
# needs to be converted, because json.loads was changing tuple to list
423
- model_config = cls ._convert_dims_to_tuple (json .loads (idata .attrs ["model_config" ]))
430
+ model_config = cls ._model_config_formatting (json .loads (idata .attrs ["model_config" ]))
424
431
model = cls (
425
432
model_config = model_config ,
426
433
sampler_config = json .loads (idata .attrs ["sampler_config" ]),
0 commit comments