@@ -516,7 +516,10 @@ def test_linear_kbit_fp32_bias(module):
516516modules .append (bnb .nn .LinearNF4 )
517517modules .append (lambda d1 , d2 : bnb .nn .LinearFP4 (d1 , d2 , compress_statistics = True ))
518518modules .append (lambda d1 , d2 : bnb .nn .LinearNF4 (d1 , d2 , compress_statistics = True ))
519- names = ['Int8Lt' , '4bit' , 'FP4' , 'NF4' , 'FP4+C' , 'NF4+C' ]
519+ modules .append (lambda d1 , d2 : bnb .nn .LinearFP4 (d1 , d2 , compute_dtype = torch .float32 ))
520+ modules .append (lambda d1 , d2 : bnb .nn .LinearFP4 (d1 , d2 , compute_dtype = torch .float16 ))
521+ modules .append (lambda d1 , d2 : bnb .nn .LinearFP4 (d1 , d2 , compute_dtype = torch .bfloat16 ))
522+ names = ['Int8Lt' , '4bit' , 'FP4' , 'NF4' , 'FP4+C' , 'NF4+C' , 'NF4+fp32' , 'NF4+fp16' , 'NF4+bf16' ]
520523@pytest .mark .skipif (not torch .cuda .is_available (), reason = "this test requires a GPU" )
521524@pytest .mark .parametrize ("module" , modules , ids = names )
522525def test_kbit_backprop (module ):
@@ -563,10 +566,10 @@ def test_kbit_backprop(module):
563566 relerrs2 .append (relerr2 .mean ().item ())
564567
565568 if isinstance (module , bnb .nn .Linear8bitLt ):
566- torch . testing . assert_close (grad1 , grad2 , atol = 0.008 , rtol = 0.05 )
569+ assert_all_approx_close (grad1 , grad2 , atol = 0.008 , rtol = 0.05 , count = 1 )
567570 torch .testing .assert_close (bgrad1 , bgrad2 , atol = 0.008 , rtol = 0.05 )
568571 else :
569- torch . testing . assert_close (grad1 , grad2 , atol = 0.015 , rtol = 0.05 )
572+ assert_all_approx_close (grad1 , grad2 , atol = 0.015 , rtol = 0.05 , count = 1 )
570573 torch .testing .assert_close (bgrad1 , bgrad2 , atol = 0.02 , rtol = 0.05 )
571574 ref .zero_grad ()
572575 kbit .zero_grad ()
@@ -608,9 +611,33 @@ def test_fp8linear():
608611 assert graderr < 0.00002
609612 assert bgraderr < 0.00002
610613
611-
612-
613-
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
def test_4bit_warnings():
    """Check that Linear4bit emits its slow-path UserWarnings.

    Feeding half-precision inputs through layers built with
    ``compute_dtype=torch.float32`` is expected to warn (presumably about
    degraded inference/training speed — exact wording owned by bnb.nn.Linear4bit;
    confirm against that module if the messages change).
    """
    dim1 = 64

    # Batched input (batch > 1): expect the "inference or training" warning.
    with pytest.warns(UserWarning, match=r'inference or training'):
        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for _ in range(10)])
        net = net.cuda()
        inp = torch.rand(10, dim1).cuda().half()
        net(inp)

    # Single-row input: expect the inference-only warning.
    # NOTE: dot escaped — the original r'inference.' used a regex wildcard
    # where a literal period was intended (same matches, clearer intent).
    with pytest.warns(UserWarning, match=r'inference\.'):
        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for _ in range(10)])
        net = net.cuda()
        inp = torch.rand(1, dim1).cuda().half()
        net(inp)

    # Running both cases under one recorder must yield exactly one warning
    # per case — i.e. the warning fires once per forward, not per layer.
    with pytest.warns(UserWarning) as record:
        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for _ in range(10)])
        net = net.cuda()
        inp = torch.rand(10, dim1).cuda().half()
        net(inp)

        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for _ in range(10)])
        net = net.cuda()
        inp = torch.rand(1, dim1).cuda().half()
        net(inp)

    assert len(record) == 2
614641
615642
616643
0 commit comments