@@ -332,12 +332,13 @@ def test_linear8bitlt_inference(threshold):
 def test_linear8bitlt_accumulated_gradient():
     l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)])
     l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
-    l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
-    l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
-    l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
-    l2[1].bias = torch.nn.Parameter(l1[1].bias.clone())
-    opt1 = bnb.optim.Adam8bit(l1.parameters(), lr=0.001)
-    opt2 = bnb.optim.Adam8bit(l2.parameters(), lr=0.001)
+    l1[0].weight.data.copy_(l2[0].weight.data)
+    l1[1].weight.data.copy_(l2[1].weight.data)
+    l1[0].bias.data.copy_(l2[0].bias.data)
+    l1[1].bias.data.copy_(l2[1].bias.data)
+
+    opt1 = bnb.optim.Adam32bit(l1.parameters(), lr=0.001)
+    opt2 = bnb.optim.Adam32bit(l2.parameters(), lr=0.001)
 
     acc_steps = 10
 
@@ -353,7 +354,6 @@ def test_linear8bitlt_accumulated_gradient():
         assert l1[0].state.CxB is not None
         assert l1[1].state.CxB is not None
 
-        print(i)
         if i > 0 and i % acc_steps == 0:
             opt1.step()
             opt1.zero_grad(True)
@@ -368,9 +368,11 @@ def test_linear8bitlt_accumulated_gradient():
             # we do this copy because otherwise we have small divergences over time that add up
             l1[0].weight.data.copy_(l2[0].weight.data)
             l1[1].weight.data.copy_(l2[1].weight.data)
+            l1[0].bias.data.copy_(l2[0].bias.data)
+            l1[1].bias.data.copy_(l2[1].bias.data)
         else:
-            torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad)
-            torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad)
+            torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad, atol=1e-3, rtol=1e-3)
+            torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad, atol=1e-3, rtol=1e-3)
 
 
 @pytest.mark.parametrize("threshold", [0.0, 2.0])
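
For context, the pattern this test exercises is a side-by-side gradient-accumulation run of two identical stacks: parameters are synced up front, the optimizers step only every `acc_steps` iterations, and in between the accumulated gradients of the two stacks are compared. The sketch below reproduces that pattern with plain `torch.nn.Linear` and `torch.optim.Adam` as stand-ins for the 8-bit module and optimizer, so the exact modules, shapes, and tolerances here are illustrative rather than the test's own.

```python
# Minimal sketch of the accumulated-gradient parity pattern (plain PyTorch,
# CPU-only stand-ins; the real test uses bnb.nn.Linear8bitLt on CUDA).
import torch

l1 = torch.nn.Linear(32, 32)
l2 = torch.nn.Linear(32, 32)
l1.weight.data.copy_(l2.weight.data)  # start from identical parameters
l1.bias.data.copy_(l2.bias.data)

opt1 = torch.optim.Adam(l1.parameters(), lr=0.001)
opt2 = torch.optim.Adam(l2.parameters(), lr=0.001)

acc_steps = 10
for i in range(40):
    b = torch.randn(16, 32)
    l1(b).mean().backward()  # gradients accumulate across iterations
    l2(b).mean().backward()

    if i > 0 and i % acc_steps == 0:
        # step only every acc_steps iterations, then clear the accumulated grads
        opt1.step()
        opt1.zero_grad(True)
        opt2.step()
        opt2.zero_grad(True)
        # re-sync parameters so small numerical divergences do not add up over time
        l1.weight.data.copy_(l2.weight.data)
        l1.bias.data.copy_(l2.bias.data)
    else:
        # between optimizer steps the accumulated gradients should agree
        torch.testing.assert_close(l1.weight.grad, l2.weight.grad, atol=1e-3, rtol=1e-3)
```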