From fd0893a16629f511a9e92a0efdeccb851b49b7fa Mon Sep 17 00:00:00 2001 From: hujiayucc Date: Fri, 12 Sep 2025 19:17:23 +0800 Subject: [PATCH] fix(security): Update the secure loading process of ModelCheckpoint --- modules/safe.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/modules/safe.py b/modules/safe.py index af019ffd980..6974f52e8c3 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -57,6 +57,9 @@ def find_class(self, module, name): if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint': import pytorch_lightning.callbacks.model_checkpoint return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint + if module == "pytorch_lightning.callbacks" and name == 'ModelCheckpoint': + import pytorch_lightning.callbacks + return pytorch_lightning.callbacks.ModelCheckpoint if module == "__builtin__" and name == 'set': return set @@ -153,6 +156,19 @@ def extra(module, name): ) return None + # Add security global variable handling + try: + # Try importing the ModelCheckpoint from PyTorch Lightning + from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint + # Add security global variables + torch.serialization.add_safe_globals([ModelCheckpoint]) + except ImportError: + # If the import fails, use a string representation + torch.serialization.add_safe_globals(['pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint']) + except AttributeError: + # If the PyTorch version does not support add_safe_globals, ignore the error + pass + return unsafe_torch_load(filename, *args, **kwargs)