@@ -14,13 +14,15 @@ def __init__(self, initial_data):
 
 
 class MLP8bit(torch.nn.Module):
-    def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
+    def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0):
         super(MLP8bit, self).__init__()
         self.fc1 = bnb.nn.Linear8bitLt(
-            dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold
+            dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
+            threshold=threshold
         )
         self.fc2 = bnb.nn.Linear8bitLt(
-            dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold
+            dim2, dim1, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
+            threshold=threshold
         )
 
     def forward(self, x):
@@ -451,9 +453,12 @@ def test_linear8bitlt_accumulated_gradient():
 
 
 @pytest.mark.parametrize("threshold", values, ids=names)
-def test_linear8bitlt_no_fp16_weights(threshold):
+@pytest.mark.parametrize("memory_efficient_backward", [True, False])
+def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
     l1 = (
-        bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False)
+        bnb.nn.Linear8bitLt(
+            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
+        )
         .cuda()
         .half()
     )
@@ -513,7 +518,9 @@ def test_linear8bitlt_no_fp16_weights(threshold):
     assert mlp.fc2.weight.dtype == torch.int8
 
     mlp = (
-        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
+        MLP8bit(
+            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
+        )
         .half()
         .to("cuda")
     )
@@ -532,7 +539,9 @@ def test_linear8bitlt_no_fp16_weights(threshold):
     assert mlp.fc2.weight.device.type == "cuda"
 
     mlp = (
-        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
+        MLP8bit(
+            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
+        )
         .to(torch.float16)
         .to("cuda")
     )
@@ -551,6 +560,7 @@ def test_linear8bitlt_no_fp16_weights(threshold):
     assert mlp.fc2.weight.device.type == "cuda"
 
 
+
 def test_linear8bitlt_fp32_bias():
     # casts model to fp16 -> int8 automatically
     l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False).cuda()
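
For reference, a minimal sketch of the pattern this diff parametrizes: constructing an int8 layer with the new memory_efficient_backward flag and running a forward/backward pass. The threshold value, input shape, and training-loop details below are illustrative assumptions, not taken from the test file.

# Hedged usage sketch (illustrative values, not from the test suite).
import torch
import bitsandbytes as bnb

# Build the int8 layer with the new flag, mirroring the constructor call added in this diff.
layer = (
    bnb.nn.Linear8bitLt(
        32, 64, threshold=6.0, has_fp16_weights=False, memory_efficient_backward=True
    )
    .cuda()
    .half()
)

# Forward and backward in fp16; the flag is meant to allow gradients to flow
# back through the int8 weights without keeping fp16 copies around.
x = torch.randn(8, 32, dtype=torch.float16, device="cuda", requires_grad=True)
layer(x).sum().backward()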