
Commit 06af920

rohitgr7 authored and lexierule committed
Fix lr_find to generate same results on multiple calls (#9704)
1 parent 9e58d8a commit 06af920
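
Why the fix works, in brief: `lr_find` runs a short trial fit that advances the trainer's `global_step` counter, but the tuner's dump/restore helpers saved `max_steps` and `current_epoch` without saving `global_step`. A second call therefore started from a non-zero step and cut its learning-rate sweep short. Adding `global_step` to the saved and restored state (the changes below) makes every call start from identical trainer state. A minimal sketch of the behaviour this guarantees, assuming pytorch-lightning ~=1.4 and the repository's BoringModel test helper (the import path is an assumption):

# Sketch mirroring the new test in tests/tuner/test_lr_finder.py.
from pytorch_lightning import Trainer
from tests.helpers import BoringModel  # test-only helper; path assumed

model = BoringModel()
trainer = Trainer(default_root_dir="/tmp/lr_find_demo", max_epochs=2)

# Run the learning-rate sweep three times on the same trainer.
all_res = [trainer.tuner.lr_find(model).results for _ in range(3)]

# Before this commit, the 2nd and 3rd sweeps resumed from the global_step
# left behind by the 1st and returned truncated, non-matching results.
assert all_res[0] == all_res[1] == all_res[2]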

5 files changed: +23 −0 lines changed


CHANGELOG.md
Lines changed: 3 additions & 0 deletions

@@ -11,6 +11,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed gradient unscaling being called too late, causing gradient clipping and gradient norm tracking to be applied incorrectly ([#9606](https://github.com/PyTorchLightning/pytorch-lightning/pull/9606))
 
 
+- Fixed `lr_find` to generate same results on multiple calls ([#9704](https://github.com/PyTorchLightning/pytorch-lightning/pull/9704))
+
+
 ## [1.4.8] - 2021-09-22
 
 - Fixed error reporting in DDP process reconciliation when processes are launched by an external agent (#9389)

pytorch_lightning/tuner/batch_size_scaling.py
Lines changed: 2 additions & 0 deletions

@@ -101,6 +101,7 @@ def __scale_batch_dump_params(trainer: "pl.Trainer") -> None:
     trainer.__dumped_params = {
         "auto_lr_find": trainer.auto_lr_find,
         "current_epoch": trainer.current_epoch,
+        "global_step": trainer.global_step,
         "max_steps": trainer.max_steps,
         "weights_summary": trainer.weights_summary,
         "logger": trainer.logger,
@@ -128,6 +129,7 @@ def __scale_batch_reset_params(trainer: "pl.Trainer", model: "pl.LightningModule
 def __scale_batch_restore_params(trainer: "pl.Trainer") -> None:
     trainer.auto_lr_find = trainer.__dumped_params["auto_lr_find"]
     trainer.fit_loop.current_epoch = trainer.__dumped_params["current_epoch"]
+    trainer.fit_loop.global_step = trainer.__dumped_params["global_step"]
     trainer.fit_loop.max_steps = trainer.__dumped_params["max_steps"]
     trainer.weights_summary = trainer.__dumped_params["weights_summary"]
     trainer.logger = trainer.__dumped_params["logger"]
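
The same dump/restore pattern, reduced to its essentials: the tuner snapshots every trainer attribute it is about to mutate, runs its trial, and writes the snapshot back. A simplified, hypothetical sketch (the real helpers are the private functions in the diffs, which keep the snapshot on trainer.__dumped_params and restore loop-owned counters via trainer.fit_loop):

# Hypothetical, simplified illustration of the tuner's dump/restore
# pattern; not the library's actual API.
_TUNED_ATTRS = ("current_epoch", "global_step", "max_steps")

def dump_params(trainer):
    trainer._dumped_params = {a: getattr(trainer, a) for a in _TUNED_ATTRS}

def restore_params(trainer):
    # Omitting an attribute here (as happened with global_step) leaks
    # tuner state into the user's Trainer and breaks reproducibility.
    for attr, value in trainer._dumped_params.items():
        setattr(trainer, attr, value)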

pytorch_lightning/tuner/lr_finder.py
Lines changed: 2 additions & 0 deletions

@@ -284,6 +284,7 @@ def __lr_finder_dump_params(trainer, model):
         "auto_lr_find": trainer.auto_lr_find,
         "callbacks": trainer.callbacks,
         "logger": trainer.logger,
+        "global_step": trainer.global_step,
         "max_steps": trainer.max_steps,
         "checkpoint_callback": trainer.checkpoint_callback,
         "current_epoch": trainer.current_epoch,
@@ -295,6 +296,7 @@ def __lr_finder_restore_params(trainer, model):
     trainer.auto_lr_find = trainer.__dumped_params["auto_lr_find"]
     trainer.logger = trainer.__dumped_params["logger"]
     trainer.callbacks = trainer.__dumped_params["callbacks"]
+    trainer.fit_loop.global_step = trainer.__dumped_params["global_step"]
     trainer.fit_loop.max_steps = trainer.__dumped_params["max_steps"]
     trainer.fit_loop.current_epoch = trainer.__dumped_params["current_epoch"]
     model.configure_optimizers = trainer.__dumped_params["configure_optimizers"]
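
One detail worth noting in this diff: the snapshot reads `trainer.global_step`, while the restore writes `trainer.fit_loop.global_step`. In this version of Lightning the trainer exposes `global_step` as a read-only property delegating to the fit loop, so loop-owned counters have to be restored on the loop itself (the batch-size scaler above does the same).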

tests/tuner/test_lr_finder.py
Lines changed: 15 additions & 0 deletions

@@ -75,6 +75,7 @@ def test_trainer_reset_correctly(tmpdir):
         "checkpoint_callback",
         "current_epoch",
         "logger",
+        "global_step",
         "max_steps",
     ]
     expected = {ca: getattr(trainer, ca) for ca in changed_attributes}
@@ -282,3 +283,17 @@ def training_step_end(self, outputs):
     trainer = Trainer(default_root_dir=tmpdir)
     num_training = 3
     trainer.tuner.lr_find(model=model, num_training=num_training)
+
+
+def test_multiple_lr_find_calls_gives_same_results(tmpdir):
+    """Tests that lr_finder gives same results if called multiple times."""
+    model = BoringModel()
+
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=2)
+    all_res = [trainer.tuner.lr_find(model).results for _ in range(3)]
+
+    assert all(
+        all_res[0][k] == curr_lr_finder[k] and len(curr_lr_finder[k]) > 10
+        for curr_lr_finder in all_res[1:]
+        for k in all_res[0].keys()
+    )
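
The new test compares every key of the first sweep's results against the two later sweeps; the `len(curr_lr_finder[k]) > 10` guard keeps the check meaningful, since a stale `global_step` would truncate the sweep and three identical but near-empty result lists would otherwise pass.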

tests/tuner/test_scale_batch_size.py
Lines changed: 1 addition & 0 deletions

@@ -108,6 +108,7 @@ def test_trainer_reset_correctly(tmpdir):
         "limit_train_batches",
         "logger",
         "max_steps",
+        "global_step",
         "weights_summary",
     ]
     expected = {ca: getattr(trainer, ca) for ca in changed_attributes}
