@@ -6,12 +6,14 @@
     BOOLEAN_TRIPLES,
     TRUE_FALSE,
     describe_dtype,
+    get_available_devices,
     id_formatter,
 )
 
 TRANSPOSE_VALS = [(False, True), (False, False)]
 
 
+@pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("dim1", [40], ids=id_formatter("dim1"))
 @pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2"))
 @pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3"))
@@ -27,32 +29,38 @@
 @pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose"))
 @pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
 @pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
-def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias):
+def test_matmullt(
+    device, dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias
+):
+    if device != "cuda" and funcs[1] == bnb.research.switchback_bnb:
+        # TODO: Deprecate/remove?
+        pytest.skip("switchback_bnb only works on CUDA.")
+
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
     dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
-    outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda")
+    outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device=device)
     if has_bias == False:
         req_grad = list(req_grad)
         req_grad[2] = False
 
     for i in range(3):
         # normal multiply
         if funcs[0] in [torch.mm, torch.matmul]:
-            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
+            A = torch.randn(size=dimA, device=device, requires_grad=req_grad[0], dtype=dtype)
             if decomp == 6.0:
                 with torch.no_grad():
                     A[:, outlier_dim] = 6.0
-            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
+            B = torch.randn(size=dimB, device=device, requires_grad=req_grad[1], dtype=dtype)
             target = torch.randn(
                 size=(dim2, dim4),
-                device="cuda",
+                device=device,
                 requires_grad=req_grad[1],
                 dtype=dtype,
             )
             bias = None
             bias2 = None
             if has_bias:
-                bias = torch.randn(dim4, device="cuda", dtype=dtype, requires_grad=req_grad[2])
+                bias = torch.randn(dim4, device=device, dtype=dtype, requires_grad=req_grad[2])
                 bias2 = bias.clone()
             torch.nn.init.xavier_uniform_(B)
             B2 = B.clone()
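
For context, get_available_devices() comes from the project's test utilities. A minimal sketch of what such a helper could look like (hypothetical, not the actual implementation, which may probe additional backends):

# Hypothetical sketch of a get_available_devices() test helper.
import torch

def get_available_devices():
    devices = ["cpu"]
    if torch.cuda.is_available():
        devices.append("cuda")
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        devices.append("mps")
    return devices
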
@@ -91,7 +99,8 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
         if has_fp16_weights:
             if any(req_grad):
                 out_bnb.data.copy_(out_torch)
-                torch.cuda.synchronize()
+                if device == "cuda":
+                    torch.cuda.synchronize()
                 loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                 loss_bnb.backward()
                 gradA1 = A.grad
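
The guard above replaces an unconditional torch.cuda.synchronize(), which would fail on machines without CUDA. An illustrative helper capturing the same pattern (the name sync is hypothetical, not part of the test suite):

import torch

def sync(device: str) -> None:
    # CPU ops complete eagerly; only CUDA needs an explicit barrier
    # before reading back results produced on the GPU.
    if device == "cuda":
        torch.cuda.synchronize()
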
@@ -135,6 +144,7 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
                 torch.testing.assert_close(gradBias1, gradBias2)
 
 
+@pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("dim1", [48], ids=id_formatter("dim1"))
 @pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2"))
 @pytest.mark.parametrize("dim3", [64], ids=id_formatter("dim3"))
@@ -147,6 +157,7 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
 @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
 @pytest.mark.parametrize("quant_type", ["fp4", "nf4"], ids=id_formatter("quant_type"))
 def test_matmul_4bit(
+    device,
     dim1,
     dim2,
     dim3,
@@ -159,6 +170,9 @@ def test_matmul_4bit(
     compress_statistics,
     quant_type,
 ):
+    if device == "cpu" and quant_type == "fp4":
+        pytest.skip("Only nf4 is supported on CPU")
+
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
     dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
     if has_bias == False:
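
The skip added in this hunk reflects that the CPU backend only implements NF4. As a rough sketch of the 4-bit round trip this test builds on, assuming the public bitsandbytes.functional API (shapes and device chosen arbitrarily):

import torch
import bitsandbytes.functional as F

W = torch.randn(64, 64, dtype=torch.float16, device="cuda")
qW, quant_state = F.quantize_4bit(W, quant_type="nf4")  # packed uint8 data + scaling state
W_dq = F.dequantize_4bit(qW, quant_state)               # approximate reconstruction of W
print((W - W_dq).abs().mean())                          # small quantization error
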
@@ -168,13 +182,13 @@ def test_matmul_4bit(
     for i in range(3):
         # normal multiply
         if funcs[0] in [torch.mm, torch.matmul]:
-            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
-            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
-            target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype)
+            A = torch.randn(size=dimA, device=device, requires_grad=req_grad[0], dtype=dtype)
+            B = torch.randn(size=dimB, device=device, requires_grad=req_grad[1], dtype=dtype)
+            target = torch.randn(size=(dim2, dim4), device=device, requires_grad=req_grad[1], dtype=dtype)
             bias = None
             bias2 = None
             if has_bias:
-                bias = torch.randn(dim4, device="cuda", dtype=dtype, requires_grad=req_grad[2])
+                bias = torch.randn(dim4, device=device, dtype=dtype, requires_grad=req_grad[2])
                 bias2 = bias.clone()
             torch.nn.init.xavier_uniform_(B)
 
@@ -204,7 +218,8 @@ def test_matmul_4bit(
             # assert err < 0.20
             if any(req_grad):
                 out_bnb.data.copy_(out_torch)
-                torch.cuda.synchronize()
+                if device == "cuda":
+                    torch.cuda.synchronize()
                 loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                 loss_bnb.backward()
                 gradA1 = A.grad
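
For reference, a sketch of the kind of call test_matmul_4bit exercises, assuming bnb.matmul_4bit and the quant_state produced by quantize_4bit (shapes arbitrary; the B.t() convention mirrors how Linear4bit passes its packed weight):

import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

A = torch.randn(8, 64, dtype=torch.float16, device="cuda", requires_grad=True)
B = torch.randn(64, 32, dtype=torch.float16, device="cuda")
qB, state = F.quantize_4bit(B, quant_type="nf4")
out = bnb.matmul_4bit(A, qB.t(), state)  # dequantizes B on the fly; out has shape (8, 32)
out.sum().backward()                     # gradients flow back to A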