44"""
55
66import os
7+ import warnings
78from abc import ABC
89from copy import deepcopy
910from collections import OrderedDict
@@ -747,6 +748,7 @@ def get_unetr(
     decoder_state: Optional[OrderedDict[str, torch.Tensor]] = None,
     device: Optional[Union[str, torch.device]] = None,
     out_channels: int = 3,
+    flexible_load_checkpoint: bool = False,
 ) -> torch.nn.Module:
     """Get UNETR model for automatic instance segmentation.

@@ -756,6 +758,8 @@ def get_unetr(
         decoder_state: Optional decoder state to initialize the weights of the UNETR decoder.
         device: The device.
         out_channels: The number of output channels.
+        flexible_load_checkpoint: Whether to allow reinitialization of parameters
+            which could not be found in the provided decoder state.

     Returns:
         The UNETR model.
@@ -775,7 +779,18 @@ def get_unetr(
     unetr_state_dict = unetr.state_dict()
     for k, v in unetr_state_dict.items():
         if not k.startswith("encoder"):
-            unetr_state_dict[k] = decoder_state[k]
+            if flexible_load_checkpoint:  # Whether to allow reinitialization of parameters, if not found.
+                if k in decoder_state:  # First check whether the key is available in the provided decoder state.
+                    unetr_state_dict[k] = decoder_state[k]
+                else:  # Otherwise, allow reinitializing the parameter.
+                    warnings.warn(f"Could not find '{k}' in the pretrained state dict. Hence, we reinitialize it.")
+                    unetr_state_dict[k] = v
+
+            else:  # Otherwise, be strict about finding the parameter in the decoder state.
+                if k not in decoder_state:
+                    raise RuntimeError(f"The parameters for '{k}' could not be found.")
+                unetr_state_dict[k] = decoder_state[k]
+
     unetr.load_state_dict(unetr_state_dict)

     unetr.to(device)
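
As a quick sanity check of the new behaviour, the snippet below reimplements the state-dict merge above as a small standalone helper. All names here (merge_decoder_state, the toy Linear module) are hypothetical and only illustrate the logic; the encoder-key filter is dropped for brevity. With flexible loading, a key missing from the decoder state triggers a warning and the freshly initialized value is kept; in strict mode the same situation raises a RuntimeError.

import warnings

import torch


def merge_decoder_state(model, decoder_state, flexible=False):
    # Hypothetical standalone version of the merge performed in get_unetr above.
    state_dict = model.state_dict()
    for k, v in state_dict.items():
        if k in decoder_state:
            state_dict[k] = decoder_state[k]
        elif flexible:
            warnings.warn(f"Could not find '{k}' in the pretrained state dict. Hence, we reinitialize it.")
            state_dict[k] = v  # keep the freshly initialized parameter
        else:
            raise RuntimeError(f"The parameters for '{k}' could not be found.")
    model.load_state_dict(state_dict)


model = torch.nn.Linear(4, 2)  # toy stand-in for the UNETR decoder
partial_state = {"weight": torch.zeros(2, 4)}  # 'bias' is deliberately missing

merge_decoder_state(model, partial_state, flexible=True)   # warns, keeps the fresh 'bias'
# merge_decoder_state(model, partial_state, flexible=False)  # would raise a RuntimeError

One plausible motivation for the flag (not stated in the diff itself) is reusing a partially matching pretrained decoder: parameters present in the checkpoint are loaded, while newly added layers simply keep their random initialization instead of aborting the whole load.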