@@ -606,18 +606,21 @@ def configure_optimizers(self):
 
 
 def test_gradient_correctness():
-    """Test that torch.gradient uses correct spacing parameter."""
+    """Test that gradients are computed with log-spaced learning rates in exponential mode."""
     lr_finder = _LRFinder(mode="exponential", lr_min=1e-6, lr_max=1e-1, num_training=20)
 
-    # Synthetic example
-    lrs = torch.linspace(0, 2 * math.pi, steps=1000)
+    lrs = torch.logspace(-6, -1, steps=1000)
     losses = torch.sin(lrs)
     lr_finder.results = {"lr": lrs.tolist(), "loss": losses.tolist()}
 
-    # Test the suggestion method
     suggestion = lr_finder.suggestion(skip_begin=2, skip_end=2)
     assert suggestion is not None
-    assert abs(suggestion - math.pi) < 1e-2, "Suggestion should be close to pi for this synthetic example"
+
+    losses_t = torch.tensor(lr_finder.results["loss"][2:-2])
+    lrs_t = torch.tensor(lr_finder.results["lr"][2:-2])
+    gradients = torch.gradient(losses_t, spacing=[torch.log10(lrs_t)])[0]
+    expected_idx = torch.argmin(gradients).item() + 2
+    assert math.isclose(suggestion, lrs[expected_idx].item())
 
 
 def test_lr_finder_callback_applies_lr_after_restore(tmp_path):
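Note (not part of the diff): the updated test leans on torch.gradient's spacing argument, which accepts a list containing a 1-D coordinate tensor and differentiates with respect to those coordinates instead of assuming unit steps. A minimal sketch of the idea, with invented values, might look like this:

import torch

# Exponentially spaced learning rates, as produced by the "exponential" LR finder mode.
lrs = torch.logspace(-6, -1, steps=50)
losses = torch.sin(lrs)

# Derivative of the loss with respect to log10(lr): passing the log-spaced
# coordinates via `spacing` makes torch.gradient use non-uniform step sizes.
grad_log = torch.gradient(losses, spacing=[torch.log10(lrs)])[0]

# Without `spacing`, torch.gradient assumes unit spacing between samples,
# which distorts the derivative when the x-axis is not uniform.
grad_unit = torch.gradient(losses)[0]

# The point of steepest descent on the loss-vs-log10(lr) curve, mirroring how
# the test above recomputes the expected index for the suggestion.
suggestion = lrs[torch.argmin(grad_log)].item()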
@@ -738,15 +741,12 @@ def __init__(self):
         lrs_filtered = lrs[is_finite]
 
         if len(losses_filtered) >= 2:
-            # Test that gradient computation works and produces finite results
-            gradients = torch.gradient(losses_filtered, spacing=[lrs_filtered])[0]
+            spacing = [torch.log10(lrs_filtered)] if mode == "exponential" else [lrs_filtered]
+            gradients = torch.gradient(losses_filtered, spacing=spacing)[0]
             assert torch.isfinite(gradients).all(), f"Non-finite gradients in {mode} mode"
             assert len(gradients) == len(losses_filtered)
 
-            # Verify gradients with spacing differ from gradients without spacing
             gradients_no_spacing = torch.gradient(losses_filtered)[0]
-
-            # For exponential mode, these should definitely be different, for linear mode, they might be similar
             if mode == "exponential":
                 assert not torch.allclose(gradients, gradients_no_spacing, rtol=0.1), (
                     "Gradients should differ significantly in exponential mode when using proper spacing"
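Note (not part of the diff): the second hunk picks the gradient spacing by mode, since in exponential mode the learning rates are geometrically spaced, so differentiating against log10(lr) and against plain indices give visibly different results. A rough, self-contained illustration with an invented loss curve and a hypothetical helper `lr_gradients`:

import torch

def lr_gradients(losses: torch.Tensor, lrs: torch.Tensor, mode: str) -> torch.Tensor:
    # Mirrors the spacing choice in the test above (illustrative only):
    # exponential mode differentiates against log10(lr), linear mode against lr itself.
    spacing = [torch.log10(lrs)] if mode == "exponential" else [lrs]
    return torch.gradient(losses, spacing=spacing)[0]

lrs = torch.logspace(-6, -1, steps=100)   # geometric spacing, as in exponential mode
losses = torch.exp(-1e4 * lrs)            # invented loss curve

grads = lr_gradients(losses, lrs, mode="exponential")
grads_unit = torch.gradient(losses)[0]    # implicit unit spacing
print(torch.allclose(grads, grads_unit, rtol=0.1))  # expected: False, as the test asserts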