@@ -24,12 +24,18 @@ def test_no_gradients(optimizer_name):
     else:
         optimizer = load_optimizer(optimizer_name)(params)
 
+    def sphere_loss(x) -> torch.Tensor:
+        return (x ** 2).sum()
+
     optimizer.zero_grad()
-    p1.grad = torch.zeros(1, 1)
-    p2.grad = None
-    p3.grad = torch.zeros(1, 1)
-    p4.grad = None
+    sphere_loss(p1 + p3).backward(create_graph=True)
+    # p1.grad = torch.zeros(1, 1)
+    # p2.grad = None
+    # p3.grad = torch.zeros(1, 1)
+    # p4.grad = None
     optimizer.step(lambda: 0.1)  # for AliG optimizer
+    if optimizer_name != 'lookahead':
+        optimizer.zero_grad(set_to_none=True)
 
 
 @pytest.mark.parametrize('no_sparse_optimizer', NO_SPARSE_OPTIMIZERS)
@@ -109,12 +115,17 @@ def test_bf16_gradient(optimizer_name):
     if optimizer_name == 'shampoo':
         pytest.skip(f'skip {optimizer_name}')
 
+    def sphere_loss(x) -> torch.Tensor:
+        return (x ** 2).sum()
+
     param = torch.randn(1, 1).bfloat16().requires_grad_(True)
-    param.grad = torch.randn(1, 1).bfloat16()
 
     opt = load_optimizer(optimizer=optimizer_name)
     optimizer = opt([param], num_iterations=1) if optimizer_name == 'ranger21' else opt([param])
+
+    sphere_loss(param).backward(create_graph=True)
     optimizer.step(lambda: 0.1)
+    optimizer.zero_grad(True)
 
 
 def test_sam_no_gradient():
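Both hunks make the same change: instead of assigning .grad tensors by hand, the tests now produce real gradients by running a tiny sphere loss (a sum of squares) through autograd, passing create_graph=True (presumably to keep the graph available for optimizers that use higher-order terms) and clearing gradients after the step. A minimal standalone sketch of that pattern, using plain torch.optim.SGD and a toy parameter in place of the library's load_optimizer fixtures:

    import torch

    def sphere_loss(x: torch.Tensor) -> torch.Tensor:
        # sum of squares; d(loss)/dx is simply 2 * x
        return (x ** 2).sum()

    # illustrative parameter; the tests build theirs via load_optimizer(...)
    param = torch.randn(1, 1, requires_grad=True)
    optimizer = torch.optim.SGD([param], lr=0.1)

    optimizer.zero_grad()
    # backward() fills param.grad through autograd rather than manual assignment;
    # create_graph=True retains the graph so the gradient itself stays differentiable
    sphere_loss(param).backward(create_graph=True)

    optimizer.step()
    optimizer.zero_grad(set_to_none=True)  # drop the grads and the retained graph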