Skip to content

Commit a798a56

Browse files
committed
fixes for compatibility with new torch versions #5
1 parent b11f411 commit a798a56

File tree

4 files changed

+5
-5
lines changed

4 files changed

+5
-5
lines changed

nnunetv2/inference/predict_from_raw_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def initialize_from_trained_model_folder(self, model_training_output_dir: str,
8282
for i, f in enumerate(use_folds):
8383
f = int(f) if f != 'all' else f
8484
checkpoint = torch.load(join(model_training_output_dir, f'fold_{f}', checkpoint_name),
85-
map_location=torch.device('cpu'))
85+
map_location=torch.device('cpu'), weights_only=False)
8686
if i == 0:
8787
trainer_name = checkpoint['trainer_name']
8888
configuration_name = checkpoint['init_args']['configuration']

nnunetv2/run/load_pretrained_weights.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ def load_pretrained_weights(network, fname, verbose=False):
1717
1818
"""
1919
if dist.is_initialized():
20-
saved_model = torch.load(fname, map_location=torch.device('cuda', dist.get_rank()))
20+
saved_model = torch.load(fname, map_location=torch.device('cuda', dist.get_rank()), weights_only=False)
2121
else:
22-
saved_model = torch.load(fname)
22+
saved_model = torch.load(fname, weight_only=False)
2323
pretrained_dict = saved_model['network_weights']
2424

2525
skip_strings_in_pretrained = [

nnunetv2/training/lr_scheduler/polylr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ def __init__(self, optimizer, initial_lr: float, max_steps: int, exponent: float
88
self.max_steps = max_steps
99
self.exponent = exponent
1010
self.ctr = 0
11-
super().__init__(optimizer, current_step if current_step is not None else -1, False)
11+
super().__init__(optimizer, current_step if current_step is not None else -1)
1212

1313
def step(self, current_step=None):
1414
if current_step is None or current_step == -1:

nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1181,7 +1181,7 @@ def load_checkpoint(self, filename_or_checkpoint: Union[dict, str]) -> None:
11811181
self.initialize()
11821182

11831183
if isinstance(filename_or_checkpoint, str):
1184-
checkpoint = torch.load(filename_or_checkpoint, map_location=self.device)
1184+
checkpoint = torch.load(filename_or_checkpoint, map_location=self.device, weights_only=False)
11851185
# if state dict comes from nn.DataParallel but we use non-parallel model here then the state dict keys do not
11861186
# match. Use heuristic to make it match
11871187
new_state_dict = {}

0 commit comments

Comments
 (0)