Skip to content

Commit 46e6a08

Browse files
authored
Add custom pickler to handle generalists (#318)
1 parent ca7e02c commit 46e6a08

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

micro_sam/instance_segmentation.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""
66

77
import os
8+
import pickle
89
import warnings
910
from abc import ABC
1011
from collections import OrderedDict
@@ -748,7 +749,12 @@ def load_instance_segmentation_with_decoder_from_checkpoint(
748749
InstanceSegmentationWithDecoder
749750
"""
750751
device = util.get_device(device)
751-
state = torch.load(checkpoint, map_location=device)
752+
753+
# over-ride the unpickler with our custom one
754+
custom_pickle = pickle
755+
custom_pickle.Unpickler = util._CustomUnpickler
756+
757+
state = torch.load(checkpoint, map_location=device, pickle_module=custom_pickle)
752758

753759
# Get the predictor.
754760
model_state = state["model_state"]

0 commit comments

Comments
 (0)