Commit eed3b69

fix impl for exponential spacing
1 parent 6a09f27

File tree

1 file changed (+9, -6)

src/lightning/pytorch/tuner/lr_finder.py

Lines changed: 9 additions & 6 deletions
@@ -190,7 +190,10 @@ def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float]
 
         """
         losses = torch.tensor(self.results["loss"][skip_begin:-skip_end])
-        losses = losses[torch.isfinite(losses)]
+        lrs = torch.tensor(self.results["lr"][skip_begin:-skip_end])
+        is_finite = torch.isfinite(losses)
+        losses = losses[is_finite]
+        lrs = lrs[is_finite]
 
         if len(losses) < 2:
             # computing torch.gradient requires at least 2 points
@@ -201,12 +204,12 @@ def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float]
             self._optimal_idx = None
             return None
 
-        # TODO: When computing the argmin here, and some losses are non-finite, the expected indices could be
-        # incorrectly shifted by an offset
-        gradients = torch.gradient(losses)[0]  # Unpack the tuple
+        gradients = torch.gradient(losses, spacing=[lrs])[0]  # Compute the gradient of losses w.r.t. learning rates
         min_grad = torch.argmin(gradients).item()
-
-        self._optimal_idx = min_grad + skip_begin
+        all_losses_idx = torch.arange(len(self.results["loss"]))
+        idx_non_skipped = all_losses_idx[skip_begin:-skip_end]
+        idx_finite = idx_non_skipped[is_finite]
+        self._optimal_idx = idx_finite[min_grad].item()
         return self.results["lr"][self._optimal_idx]
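Two things change in this diff. First, torch.gradient(losses, spacing=[lrs]) differentiates the loss with respect to the recorded learning-rate values instead of assuming unit spacing between points, which matters because the LR finder sweeps the learning rate on an exponential scale. Second, since non-finite losses are now masked out, the argmin over the filtered tensor is mapped back to an index into the full results lists rather than being offset by skip_begin alone. Below is a minimal, standalone sketch of the same logic, using synthetic lrs/losses tensors and made-up skip values rather than the Lightning _LRFinder object:

    import torch

    # Synthetic stand-ins for the values the LR finder records during a sweep
    # (assumption: an exponential sweep and a dummy loss curve with one NaN).
    lrs = torch.logspace(-6, -1, steps=20)
    losses = torch.cat([torch.linspace(1.0, 0.2, 15), torch.linspace(0.3, 5.0, 5)])
    losses[5] = float("nan")  # simulate a diverged, non-finite step

    skip_begin, skip_end = 2, 1

    # Trim the edges of the sweep, then drop non-finite losses while keeping lrs aligned.
    trimmed_losses = losses[skip_begin:-skip_end]
    trimmed_lrs = lrs[skip_begin:-skip_end]
    is_finite = torch.isfinite(trimmed_losses)
    finite_losses = trimmed_losses[is_finite]
    finite_lrs = trimmed_lrs[is_finite]

    # Gradient of the loss w.r.t. the learning-rate coordinates
    # (torch.gradient returns a tuple, hence the [0]).
    gradients = torch.gradient(finite_losses, spacing=[finite_lrs])[0]
    min_grad = torch.argmin(gradients).item()

    # Map the argmin over the filtered points back to an index into the full arrays,
    # accounting for both the trimmed edges and the dropped non-finite entries.
    all_idx = torch.arange(len(losses))
    idx_finite = all_idx[skip_begin:-skip_end][is_finite]
    optimal_idx = idx_finite[min_grad].item()
    print(f"suggested lr: {lrs[optimal_idx].item():.2e} (index {optimal_idx})")

For comparison, torch.gradient(finite_losses)[0] (the previous unit-spacing behaviour) treats every step of the exponential sweep as equally wide and can therefore pick a different index on the same data.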