diff --git a/nemo/core/classes/modelPT.py b/nemo/core/classes/modelPT.py index 89aa163bf062..5520efe295ad 100644 --- a/nemo/core/classes/modelPT.py +++ b/nemo/core/classes/modelPT.py @@ -1459,7 +1459,7 @@ def maybe_init_from_pretrained_checkpoint(self, cfg: OmegaConf, map_location: st if isinstance(cfg.init_from_ptl_ckpt, str): # Restore checkpoint ckpt_path = cfg.pop('init_from_ptl_ckpt') - ckpt = torch.load(ckpt_path, map_location=map_location) + ckpt = torch.load(ckpt_path, map_location=map_location, weights_only=True) # Restore checkpoint into current model self.load_state_dict(ckpt['state_dict'], strict=False) @@ -1473,7 +1473,7 @@ def maybe_init_from_pretrained_checkpoint(self, cfg: OmegaConf, map_location: st for model_load_cfg in model_load_dict.values(): ckpt_path = model_load_cfg.path # Restore model - ckpt = torch.load(ckpt_path, map_location=map_location) + ckpt = torch.load(ckpt_path, map_location=map_location, weights_only=True) include = model_load_cfg.pop('include', [""]) exclude = model_load_cfg.pop('exclude', [])