diff --git a/modules/safe.py b/modules/safe.py index af019ffd980..6974f52e8c3 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -57,6 +57,9 @@ def find_class(self, module, name): if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint': import pytorch_lightning.callbacks.model_checkpoint return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint + if module == "pytorch_lightning.callbacks" and name == 'ModelCheckpoint': + import pytorch_lightning.callbacks + return pytorch_lightning.callbacks.ModelCheckpoint if module == "__builtin__" and name == 'set': return set @@ -153,6 +156,19 @@ def extra(module, name): ) return None + # Add security global variable handling + try: + # Try importing the ModelCheckpoint from PyTorch Lightning + from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint + # Add security global variables + torch.serialization.add_safe_globals([ModelCheckpoint]) + except ImportError: + # If the import fails, use a string representation + torch.serialization.add_safe_globals(['pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint']) + except AttributeError: + # If the PyTorch version does not support add_safe_globals, ignore the error + pass + return unsafe_torch_load(filename, *args, **kwargs)