Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions modules/safe.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ def find_class(self, module, name):
if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
import pytorch_lightning.callbacks.model_checkpoint
return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
if module == "pytorch_lightning.callbacks" and name == 'ModelCheckpoint':
import pytorch_lightning.callbacks
return pytorch_lightning.callbacks.ModelCheckpoint
if module == "__builtin__" and name == 'set':
return set

Expand Down Expand Up @@ -153,6 +156,19 @@ def extra(module, name):
)
return None

# Add security global variable handling
try:
# Try importing the ModelCheckpoint from PyTorch Lightning
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
# Add security global variables
torch.serialization.add_safe_globals([ModelCheckpoint])
except ImportError:
# If the import fails, use a string representation
torch.serialization.add_safe_globals(['pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint'])
except AttributeError:
# If the PyTorch version does not support add_safe_globals, ignore the error
pass

return unsafe_torch_load(filename, *args, **kwargs)


Expand Down
Loading