Skip to content

Commit 55e7322

Browse files
Metrics load (#228)
* load from metrics defaults to CPU * load from metrics defaults to CPU * load from metrics defaults to CPU
1 parent c0f3b6b commit 55e7322

File tree

2 files changed

+7
-13
lines changed

2 files changed

+7
-13
lines changed

pytorch_lightning/root_module/root_module.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def val_dataloader(self):
130130
return None
131131

132132
@classmethod
133-
def load_from_metrics(cls, weights_path, tags_csv, on_gpu, map_location=None):
133+
def load_from_metrics(cls, weights_path, tags_csv, on_gpu):
134134
"""
135135
Primary way of loading model from csv weights path
136136
:param weights_path:
@@ -142,13 +142,9 @@ def load_from_metrics(cls, weights_path, tags_csv, on_gpu, map_location=None):
142142
hparams = load_hparams_from_tags_csv(tags_csv)
143143
hparams.__setattr__('on_gpu', on_gpu)
144144

145-
if on_gpu:
146-
if map_location is not None:
147-
checkpoint = torch.load(weights_path, map_location=map_location)
148-
else:
149-
checkpoint = torch.load(weights_path)
150-
else:
151-
checkpoint = torch.load(weights_path, map_location=lambda storage, loc: storage)
145+
# load on CPU only to avoid OOM issues
146+
# then its up to user to put back on GPUs
147+
checkpoint = torch.load(weights_path, map_location=lambda storage, loc: storage)
152148

153149
# load the state_dict on the model automatically
154150
model = cls(hparams)

tests/test_models.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,8 +1030,7 @@ def test_amp_gpu_ddp_slurm_managed():
10301030
assert trainer.resolve_root_node_address('abc[23-24, 45-40, 40]') == 'abc23'
10311031

10321032
# test model loading with a map_location
1033-
map_location = 'cuda:1'
1034-
pretrained_model = load_model(exp, save_dir, True, map_location)
1033+
pretrained_model = load_model(exp, save_dir, True)
10351034

10361035
# test model preds
10371036
run_prediction(model.test_dataloader, pretrained_model)
@@ -1406,7 +1405,7 @@ def clear_save_dir():
14061405
shutil.move(save_dir, save_dir + f'_{n}')
14071406

14081407

1409-
def load_model(exp, save_dir, on_gpu, map_location=None, module_class=LightningTemplateModel):
1408+
def load_model(exp, save_dir, on_gpu, module_class=LightningTemplateModel):
14101409

14111410
# load trained model
14121411
tags_path = exp.get_data_path(exp.name, exp.version)
@@ -1417,8 +1416,7 @@ def load_model(exp, save_dir, on_gpu, map_location=None, module_class=LightningT
14171416

14181417
trained_model = module_class.load_from_metrics(weights_path=weights_dir,
14191418
tags_csv=tags_path,
1420-
on_gpu=on_gpu,
1421-
map_location=map_location)
1419+
on_gpu=on_gpu)
14221420

14231421
assert trained_model is not None, 'loading model failed'
14241422

0 commit comments

Comments
 (0)