diff --git a/fireredasr/models/fireredasr.py b/fireredasr/models/fireredasr.py index 9cb4e33..ba813ad 100644 --- a/fireredasr/models/fireredasr.py +++ b/fireredasr/models/fireredasr.py @@ -107,7 +107,7 @@ def transcribe(self, batch_uttid, batch_wav_path, args={}): def load_fireredasr_aed_model(model_path): - package = torch.load(model_path, map_location=lambda storage, loc: storage) + package = torch.load(model_path, weights_only=False, map_location=lambda storage, loc: storage) print("model args:", package["args"]) model = FireRedAsrAed.from_args(package["args"]) model.load_state_dict(package["model_state_dict"], strict=True) @@ -115,7 +115,7 @@ def load_fireredasr_aed_model(model_path): def load_firered_llm_model_and_tokenizer(model_path, encoder_path, llm_dir): - package = torch.load(model_path, map_location=lambda storage, loc: storage) + package = torch.load(model_path, weights_only=False, map_location=lambda storage, loc: storage) package["args"].encoder_path = encoder_path package["args"].llm_dir = llm_dir print("model args:", package["args"])