Skip to content

Commit d930618

Browse files
authored
Minor updates to UNETR model function (#843)
Updates to unetr model checkpoint loading
1 parent 552dc55 commit d930618

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

micro_sam/instance_segmentation.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""
55

66
import os
7+
import warnings
78
from abc import ABC
89
from copy import deepcopy
910
from collections import OrderedDict
@@ -747,6 +748,7 @@ def get_unetr(
747748
decoder_state: Optional[OrderedDict[str, torch.Tensor]] = None,
748749
device: Optional[Union[str, torch.device]] = None,
749750
out_channels: int = 3,
751+
flexible_load_checkpoint: bool = False,
750752
) -> torch.nn.Module:
751753
"""Get UNETR model for automatic instance segmentation.
752754
@@ -756,6 +758,8 @@ def get_unetr(
756758
decoder_state: Optional decoder state to initialize the weights of the UNETR decoder.
757759
device: The device.
758760
out_channels: The number of output channels.
761+
flexible_load_checkpoint: Whether to allow reinitialization of parameters
762+
which could not be found in the provided decoder state.
759763
760764
Returns:
761765
The UNETR model.
@@ -775,7 +779,18 @@ def get_unetr(
775779
unetr_state_dict = unetr.state_dict()
776780
for k, v in unetr_state_dict.items():
777781
if not k.startswith("encoder"):
778-
unetr_state_dict[k] = decoder_state[k]
782+
if flexible_load_checkpoint: # Whether allow reinitalization of params, if not found.
783+
if k in decoder_state: # First check whether the key is available in the provided decoder state.
784+
unetr_state_dict[k] = decoder_state[k]
785+
else: # Otherwise, allow it to initialize it.
786+
warnings.warn(f"Could not find '{k}' in the pretrained state dict. Hence, we reinitialize it.")
787+
unetr_state_dict[k] = v
788+
789+
else: # Whether be strict on finding the parameter in the decoder state.
790+
if k not in decoder_state:
791+
raise RuntimeError(f"The parameters for '{k}' could not be found.")
792+
unetr_state_dict[k] = decoder_state[k]
793+
779794
unetr.load_state_dict(unetr_state_dict)
780795

781796
unetr.to(device)

0 commit comments

Comments
 (0)