Skip to content

Commit 7208641

Browse files
matt3opre-commit-ci[bot]tangy5SachidanandAlle
authored
Fix silent fail if network loads wrong weights (#1521)
* Add warning if the imported network does not have the right keys Signed-off-by: Matthias Hadlich <[email protected]> * Update warning Signed-off-by: Matthias Hadlich <[email protected]> * Set load_strict to false for deepedit Signed-off-by: Matthias Hadlich <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Matthias Hadlich <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: tangy5 <[email protected]> Co-authored-by: SACHIDANAND ALLE <[email protected]>
1 parent 43e9e2a commit 7208641

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

monailabel/tasks/infer/basic_infer.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def __init__(
5757
output_label_key: str = "pred",
5858
output_json_key: str = "result",
5959
config: Union[None, Dict[str, Any]] = None,
60-
load_strict: bool = False,
60+
load_strict: bool = True,
6161
roi_size=None,
6262
preload=False,
6363
train_mode=False,
@@ -453,6 +453,16 @@ def _get_network(self, device, data):
453453
if path:
454454
checkpoint = torch.load(path, map_location=torch.device(device))
455455
model_state_dict = checkpoint.get(self.model_state_dict, checkpoint)
456+
457+
if set(self.network.state_dict().keys()) != set(checkpoint.keys()):
458+
logger.warning(
459+
f"Checkpoint keys don't match network.state_dict()! Items that exist in only one dict"
460+
f" but not in the other: {set(self.network.state_dict().keys()) ^ set(checkpoint.keys())}"
461+
)
462+
logger.warning(
463+
"The run will now continue unless load_strict is set to True. "
464+
"If loading fails or the network behaves abnormally, please check the loaded weights"
465+
)
456466
network.load_state_dict(model_state_dict, strict=self.load_strict)
457467
else:
458468
network = torch.jit.load(path, map_location=torch.device(device))

sample-apps/radiology/lib/infers/deepedit.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def __init__(
7171
self.spatial_size = spatial_size
7272
self.target_spacing = target_spacing
7373
self.number_intensity_ch = number_intensity_ch
74+
self.load_strict = False
7475

7576
def pre_transforms(self, data=None):
7677
t = [

0 commit comments

Comments
 (0)