@@ -86,6 +86,7 @@ def dump(obj: Any, path: Path, exists_ok: bool = False, exclude: list[type[Any]]
8686 attrs : dict [str , ModuleAttributes ] = vars (obj )
8787 simple_attrs = {}
8888 arrays : dict [str , npt .NDArray [Any ]] = {}
89+ containers = {}
8990
9091 Dumper .make_subdirectories (path , exists_ok )
9192
@@ -96,6 +97,8 @@ def dump(obj: Any, path: Path, exists_ok: bool = False, exclude: list[type[Any]]
9697 val .dump (path / Dumper .tags / key )
9798 elif isinstance (val , ModuleSimpleAttributes ):
9899 simple_attrs [key ] = val
100+ elif isinstance (val , dict ):
101+ containers [key ] = val
99102 elif isinstance (val , np .ndarray ):
100103 arrays [key ] = val
101104 elif isinstance (val , Embedder ):
@@ -154,6 +157,9 @@ def dump(obj: Any, path: Path, exists_ok: bool = False, exclude: list[type[Any]]
154157 "module" : val .__class__ .__module__ ,
155158 "name" : val .__class__ .__name__ ,
156159 }
160+ # Save configuration if available
161+ if hasattr (val , 'get_config' ):
162+ class_info ['config' ] = val .get_config ()
157163 with (model_path / "class_info.json" ).open ("w" ) as f :
158164 json .dump (class_info , f )
159165 except Exception as e :
@@ -174,6 +180,9 @@ def dump(obj: Any, path: Path, exists_ok: bool = False, exclude: list[type[Any]]
174180 with (path / Dumper .simple_attrs ).open ("w" ) as file :
175181 json .dump (simple_attrs , file , ensure_ascii = False , indent = 4 )
176182
183+ with (path / Dumper .containers / "containers.json" ).open ("w" ) as f :
184+ json .dump (containers , f , ensure_ascii = False , indent = 4 )
185+
177186 np .savez (path / Dumper .arrays , allow_pickle = False , ** arrays )
178187
179188 @staticmethod
@@ -275,32 +284,31 @@ def load( # noqa: C901, PLR0912, PLR0915
275284 try :
276285 with (model_dir / "class_info.json" ).open ("r" ) as f :
277286 class_info = json .load (f )
278-
279287 module = __import__ (class_info ["module" ], fromlist = [class_info ["name" ]])
280288 model_class = getattr (module , class_info ["name" ])
281-
282- # Create model instance
283- model = model_class ()
284-
285- # Load state dict
289+ config = class_info .get ('config' , {})
290+ # Initialize model with config if available
291+ model = model_class (** config )
286292 model .load_state_dict (torch .load (model_dir / "model.pt" ))
287293 model .eval ()
288294 torch_models [model_dir .name ] = model
289- except Exception as e : # noqa: PERF203
290- msg = f"Error loading torch model { model_dir .name } : { e } "
291- logger .exception (msg )
295+ except Exception as e :
296+ logger .exception (f"Error loading torch model { model_dir .name } : { e } " )
292297 elif child .name == Dumper .containers :
293298 try :
294299 for container_file in child .iterdir ():
300+ print (container_file )
295301 with container_file .open ("r" ) as f :
296- containers [ container_file . stem ] = json .load (f )
302+ containers = json .load (f )
297303 except Exception as e :
298304 msg = f"Error loading containers: { e } "
299305 logger .exception (msg )
300306 else :
301307 msg = f"Found unexpected child { child } "
302308 logger .error (msg )
303309
310+ print (containers )
311+
304312 obj .__dict__ .update (
305313 tags
306314 | simple_attrs
0 commit comments