
Commit 4bbcfa0

.fit() returns last not best weights in ddp_spawn (#2565)
Squashed commit messages: added base tests for tpu; enable none checkpoint.
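In short: with distributed_backend='ddp_spawn' (or 'ddp_cpu'), the spawned rank-zero process now saves the last-epoch weights to a temporary checkpoint and sends that path back to the main process, so .fit() leaves the last weights on the trainer instead of silently reloading the best checkpoint. The best checkpoint path is still recorded on the checkpoint callback. A minimal usage sketch of the resulting behavior (MyLightningModule is a hypothetical LightningModule subclass, not part of this commit):

    import pytorch_lightning as pl

    model = MyLightningModule()  # hypothetical LightningModule subclass
    trainer = pl.Trainer(gpus=2, distributed_backend='ddp_spawn', max_epochs=3)
    trainer.fit(model)  # after this commit, the trainer/model hold the last-epoch weights

    # the best checkpoint is still tracked; load it explicitly if that is what you want
    best_path = trainer.checkpoint_callback.best_model_path
    if best_path:
        best_model = MyLightningModule.load_from_checkpoint(best_path)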
Parent: e1bc208 · Commit: 4bbcfa0

File tree: 3 files changed (+55, -7 lines)
  pytorch_lightning/trainer/distrib_data_parallel.py
  pytorch_lightning/trainer/trainer.py
  tests/models/test_test_loop.py

pytorch_lightning/trainer/distrib_data_parallel.py (25 additions, 3 deletions)

@@ -189,6 +189,7 @@ class TrainerDDPMixin(ABC):
     num_nodes: int
     node_rank: int
     tpu_cores: int
+    testing: bool

     @property
     @abstractmethod
@@ -555,15 +556,35 @@ def ddp_train(self, process_idx, q, model, is_master=False, proc_offset=0):
         # continue training routine
         results = self.run_pretrain_routine(model)

+        # persist info in ddp_spawn
+        self.__transfer_ddp_spawn_state_on_fit_end(model, q, results)
+
         # clean up memory
         torch.cuda.empty_cache()

+        if self.global_rank == 0 and self.distributed_backend not in ['ddp_spawn', 'ddp_cpu']:
+            return results
+
+    def __transfer_ddp_spawn_state_on_fit_end(self, model, q, results):
+        if self.distributed_backend not in ['ddp_spawn', 'ddp_cpu']:
+            return
+
+        # track the best model path
+        best_model_path = None
+        if self.checkpoint_callback is not None:
+            best_model_path = self.checkpoint_callback.best_model_path
+
         if self.global_rank == 0 and q is not None:
-            q.put(self.checkpoint_callback.best_model_path)
+            rank_zero_warn('cleaning up ddp environment...')
+            q.put(best_model_path)
             q.put(results)

-        if self.global_rank == 0 and self.distributed_backend != 'ddp_spawn':
-            return results
+            # save the last weights
+            last_path = None
+            if not self.testing:
+                last_path = os.path.join(self.default_root_dir, '__temp_weight_ddp_end.ckpt')
+                torch.save(model.state_dict(), last_path)
+            q.put(last_path)

     def save_spawn_weights(self, model):
         """
@@ -574,6 +595,7 @@ def save_spawn_weights(self, model):
         if self.is_global_zero:
             path = os.path.join(self.default_root_dir, '__temp_weight_ddp_end.ckpt')
             self.save_checkpoint(path)
+            return path

     def load_spawn_weights(self, original_model):
         """

pytorch_lightning/trainer/trainer.py (9 additions, 4 deletions)

@@ -35,7 +35,7 @@
 from pytorch_lightning.utilities import rank_zero_warn, parsing, rank_zero_info, rank_zero_only
 import warnings

-# warnings to ignore
+# warnings to ignore in trainer
 warnings.filterwarnings('ignore', message='torch.distributed.reduce_op is deprecated, '
                                           'please use torch.distributed.ReduceOp instead')

@@ -1063,9 +1063,14 @@ def __run_ddp_spawn(self, model, nprocs):
         # restore main state with best weights
         best_path = q.get()
         results = q.get()
-        if best_path is not None and len(best_path) > 0:
-            self.checkpoint_callback.best_model_path = best_path
-            model.load_from_checkpoint(best_path)
+        last_path = q.get()
+
+        # transfer back the best path to the trainer
+        self.checkpoint_callback.best_model_path = best_path
+
+        # load last weights
+        if last_path is not None and not self.testing:
+            torch.load(last_path, map_location=lambda storage, loc: storage)

         self.model = model
         return results
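Note that the '__temp_weight_ddp_end.ckpt' file handed back through last_path is a raw state dict written with torch.save(model.state_dict(), path), not a full Lightning checkpoint. For reference, the generic PyTorch round trip for such a file looks like this (plain PyTorch, not code from this diff):

    import torch

    model = torch.nn.Linear(4, 2)
    path = '__temp_weight_ddp_end.ckpt'

    # writer side: a raw state dict, as ddp_train does above
    torch.save(model.state_dict(), path)

    # reader side: load to CPU and apply the weights to an existing module
    state_dict = torch.load(path, map_location=lambda storage, loc: storage)
    model.load_state_dict(state_dict)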

tests/models/test_test_loop.py (21 additions, 0 deletions)

@@ -23,9 +23,16 @@ def test_single_gpu_test(tmpdir):
     results = trainer.test()
     assert 'test_acc' in results

+    old_weights = model.c_d1.weight.clone().detach().cpu()
+
     results = trainer.test(model)
     assert 'test_acc' in results

+    # make sure weights didn't change
+    new_weights = model.c_d1.weight.clone().detach().cpu()
+
+    assert torch.all(torch.eq(old_weights, new_weights))
+

 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 def test_dp_test(tmpdir):
@@ -45,9 +52,16 @@ def test_dp_test(tmpdir):
     results = trainer.test()
     assert 'test_acc' in results

+    old_weights = model.c_d1.weight.clone().detach().cpu()
+
     results = trainer.test(model)
     assert 'test_acc' in results

+    # make sure weights didn't change
+    new_weights = model.c_d1.weight.clone().detach().cpu()
+
+    assert torch.all(torch.eq(old_weights, new_weights))
+

 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 def test_ddp_spawn_test(tmpdir):
@@ -67,5 +81,12 @@ def test_ddp_spawn_test(tmpdir):
     results = trainer.test()
     assert 'test_acc' in results

+    old_weights = model.c_d1.weight.clone().detach().cpu()
+
     results = trainer.test(model)
     assert 'test_acc' in results
+
+    # make sure weights didn't change
+    new_weights = model.c_d1.weight.clone().detach().cpu()
+
+    assert torch.all(torch.eq(old_weights, new_weights))
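Each test uses the same idiom: snapshot a layer's weights before trainer.test(model) and assert they are identical afterwards, since .test() must never update parameters. Factored out as a small helper, the idiom looks like this (the helper name is hypothetical; the tests simply inline it):

    import torch


    def weights_unchanged(before: torch.Tensor, module: torch.nn.Module, attr: str = 'weight') -> bool:
        """Compare a previously snapshotted tensor against the module's current parameter."""
        after = getattr(module, attr).clone().detach().cpu()
        return bool(torch.all(torch.eq(before, after)))


    # usage mirrors the tests: snapshot, run trainer.test(model), then compare
    layer = torch.nn.Linear(4, 2)
    before = layer.weight.clone().detach().cpu()
    # ... trainer.test(model) would run here ...
    assert weights_unchanged(before, layer)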
