@@ -102,12 +102,13 @@ def dump(obj: Any, path: Path) -> None: # noqa: ANN401, C901, PLR0912, PLR0915
102102 model_path = path / Dumper .torch_models / key
103103 model_path .mkdir (parents = True , exist_ok = True )
104104 try :
105- torch .save (val .state_dict (), model_path / "model.pt" )
106- # Save class info for loading
105+ torch .save (val ._model .state_dict (), model_path / "model.pt" )
106+ vocab_path = path / Dumper .torch_models / "vocab.json"
107+ with vocab_path .open ("w" ) as f :
108+ json .dump (obj ._vocab , f )
107109 class_info = {
108110 "module" : val .__class__ .__module__ ,
109111 "name" : val .__class__ .__name__ ,
110- "is_textcnn" : isinstance (val , TextCNN )
111112 }
112113 with (model_path / "class_info.json" ).open ("w" ) as f :
113114 json .dump (class_info , f )
@@ -254,25 +255,15 @@ def load( # noqa: PLR0912, C901, PLR0915
254255 try :
255256 with (model_dir / "class_info.json" ).open ("r" ) as f :
256257 class_info = json .load (f )
258+ vocab_path = path / Dumper .torch_models / "vocab.json"
259+ with vocab_path .open ("r" ) as f :
260+ obj ._vocab = json .load (f )
257261
258262 module = __import__ (class_info ["module" ], fromlist = [class_info ["name" ]])
259263 model_class = getattr (module , class_info ["name" ])
260264
261265 # Create model instance
262- if class_info .get ("is_textcnn" ):
263- # For TextCNN, we need to get the parameters from the parent CNNScorer
264- model = model_class (
265- vocab_size = len (obj ._vocab ) if hasattr (obj , "_vocab" ) and obj ._vocab else 0 ,
266- n_classes = obj ._n_classes if hasattr (obj , "_n_classes" ) else 0 ,
267- embed_dim = obj .embed_dim if hasattr (obj , "embed_dim" ) else 128 ,
268- kernel_sizes = obj .kernel_sizes if hasattr (obj , "kernel_sizes" ) else [3 , 4 , 5 ],
269- num_filters = obj .num_filters if hasattr (obj , "num_filters" ) else 100 ,
270- dropout = obj .dropout if hasattr (obj , "dropout" ) else 0.1 ,
271- padding_idx = obj ._pad_idx if hasattr (obj , "_pad_idx" ) else 0
272- )
273- else :
274- # For other torch models, create with default parameters
275- model = model_class ()
266+ model = model_class ()
276267
277268 # Load state dict
278269 model .load_state_dict (torch .load (model_dir / "model.pt" ))
0 commit comments