@@ -71,8 +71,11 @@ def test_linear8bitlt_inference(device, threshold):
 
 
 # TODO: Remove support for training int8 weights
-@pytest.mark.deprecated
+@pytest.mark.parametrize("device", get_available_devices())
 def test_linear8bitlt_accumulated_gradient(device):
+    if device != "cuda":
+        pytest.skip("Only supported on CUDA")
+
     l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).to(device).half() for i in range(2)])
     l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).to(device).half() for i in range(2)])
     l1[0].weight.data.copy_(l2[0].weight.data)
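Note: `get_available_devices()` comes from the suite's shared test helpers and is not shown in this diff. A minimal sketch of what such a helper might look like, assuming it simply probes torch for usable backends:

# Hypothetical sketch only -- the real helper lives in the shared test utilities
# and may also probe additional backends.
import torch

def get_available_devices():
    devices = ["cpu"]
    if torch.cuda.is_available():
        devices.append("cuda")
    return devices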
@@ -114,56 +117,60 @@ def test_linear8bitlt_accumulated_gradient(device):
     assert_all_approx_close(l1[1].weight.grad, l2[1].weight.grad, rtol=1.05, atol=0.04, count=1)
 
 
+@pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("threshold", [0.0, 2.0])
-def test_linear8bitlt_no_fp16_weights(threshold):
+def test_linear8bitlt_no_fp16_weights(device, threshold):
+    if device == "cpu":
+        pytest.xfail("Not yet supported on CPU")
+
     l1 = (
         bnb.nn.Linear8bitLt(
             32,
             64,
             threshold=threshold,
             has_fp16_weights=False,
         )
-        .cuda()
+        .to(device)
         .half()
     )
     assert l1.weight.dtype == torch.int8
 
     l1.eval()
     for i in range(100):
-        b1 = torch.randn(16, 8, 32, device="cuda").half()
+        b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
         o1 = l1(b1)
         assert o1.dtype == torch.float16
 
-    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda()
+    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).to(device)
     assert mlp.fc1.weight.dtype == torch.int8
     assert mlp.fc2.weight.dtype == torch.int8
 
     for i in range(100):
-        b1 = torch.randn(16, 8, 32, device="cuda").half()
+        b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
         o1 = mlp(b1)
         assert o1.dtype == torch.float16
         if threshold > 0:
             assert mlp.fc1.state.idx is not None
         if threshold > 0:
             assert mlp.fc2.state.idx is not None
 
-    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half()
+    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).to(device).half()
     assert mlp.fc1.weight.dtype == torch.int8
     assert mlp.fc2.weight.dtype == torch.int8
 
     for i in range(100):
-        b1 = torch.randn(16, 8, 32, device="cuda").half()
+        b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
         o1 = mlp(b1)
         assert o1.dtype == torch.float16
         if threshold > 0:
             assert mlp.fc1.state.idx is not None
         if threshold > 0:
             assert mlp.fc2.state.idx is not None
 
-    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda()
+    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to(device)
 
     for i in range(100):
-        b1 = torch.randn(16, 8, 32, device="cuda").half()
+        b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
         o1 = mlp(b1)
         assert o1.dtype == torch.float16
         if threshold > 0:
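Note: `MLP8bit` is a helper defined earlier in this test file and not shown in the diff. Only as a sketch inferred from how these tests use it (an `fc1`/`fc2` pair that accepts `threshold` and `has_fp16_weights`), it roughly wraps two `bnb.nn.Linear8bitLt` layers:

# Sketch inferred from usage in these tests; the real definition in the test file may differ.
import torch
import bitsandbytes as bnb

class MLP8bit(torch.nn.Module):
    def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
        super().__init__()
        self.fc1 = bnb.nn.Linear8bitLt(dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold)
        self.fc2 = bnb.nn.Linear8bitLt(dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold)

    def forward(self, x):
        return self.fc2(self.fc1(x))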
@@ -181,11 +188,11 @@ def test_linear8bitlt_no_fp16_weights(threshold):
             has_fp16_weights=False,
         )
         .half()
-        .to("cuda")
+        .to(device)
     )
 
     for i in range(100):
-        b1 = torch.randn(16, 8, 32, device="cuda").half()
+        b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
         o1 = mlp(b1)
         assert o1.dtype == torch.float16
         if threshold > 0:
@@ -194,20 +201,20 @@ def test_linear8bitlt_no_fp16_weights(threshold):
             assert mlp.fc2.state.idx is not None
     assert mlp.fc1.weight.dtype == torch.int8
     assert mlp.fc2.weight.dtype == torch.int8
-    assert mlp.fc1.weight.device.type == "cuda"
-    assert mlp.fc2.weight.device.type == "cuda"
+    assert mlp.fc1.weight.device.type == device
+    assert mlp.fc2.weight.device.type == device
 
     mlp = MLP8bit(
         32,
         64,
         threshold=threshold,
         has_fp16_weights=False,
     )
-    w1, w2 = mlp.fc1.weight.clone().cuda(), mlp.fc2.weight.clone().cuda()  # grab weights before quantization,
+    w1, w2 = mlp.fc1.weight.clone().to(device), mlp.fc2.weight.clone().to(device)  # grab weights before quantization,
     mlp = mlp.cuda().half()  # and this line triggers quantization
 
     for i in range(100):
-        b1 = torch.randn(16, 8, 32, device="cuda").half()
+        b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
         o1 = mlp(b1)
         assert o1.dtype == torch.float16
         if threshold > 0:
@@ -217,10 +224,10 @@ def test_linear8bitlt_no_fp16_weights(threshold):
 
     assert mlp.fc1.weight.dtype == torch.int8
     assert mlp.fc2.weight.dtype == torch.int8
-    assert mlp.fc1.weight.device.type == "cuda"
-    assert mlp.fc2.weight.device.type == "cuda"
+    assert mlp.fc1.weight.device.type == device
+    assert mlp.fc2.weight.device.type == device
 
-    b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half)
+    b1 = torch.randn(16, 8, 32, device=device, requires_grad=True, dtype=torch.half)
     o1 = mlp(b1)
     assert o1.dtype == torch.float16
     assert o1.requires_grad
@@ -236,33 +243,37 @@ def test_linear8bitlt_no_fp16_weights(threshold):
     assert (idx == 0).sum().item() <= b1.numel() * 0.005
 
 
+@pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize(
     "module",
     [
         lambda n_in, n_out, bias=True: bnb.nn.Linear8bitLt(n_in, n_out, bias=bias, has_fp16_weights=False),
-        bnb.nn.LinearFP4,
+        bnb.nn.LinearNF4,
     ],
-    ids=["Int8Lt", "FP4"],
+    ids=["Int8Lt", "NF4"],
 )
-def test_linear_kbit_fp32_bias(module):
+def test_linear_kbit_fp32_bias(device, module):
+    if device == "cpu":
+        pytest.xfail("Not yet implemented on CPU")
+
     # casts model to fp16 -> int8 automatically
-    l1 = module(32, 64).cuda()
+    l1 = module(32, 64).to(device)
     assert l1.weight.dtype in [torch.int8, torch.uint8]
     assert l1.bias.dtype == torch.float32
 
     for i in range(100):
-        b1 = torch.randn(16, 8, 32, device="cuda").half()
+        b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
         # casts bias to fp32
         o1 = l1(b1)
         assert l1.bias.dtype == torch.float16
 
     # casts model to fp16 -> int8 automatically
-    l1 = module(32, 64, bias=False).cuda()
+    l1 = module(32, 64, bias=False).to(device)
     assert l1.weight.dtype in [torch.int8, torch.uint8]
     assert l1.bias is None
 
     for i in range(100):
-        b1 = torch.randn(16, 8, 32, device="cuda").half()
+        b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
         o1 = l1(b1)
         assert l1.bias is None
 
@@ -280,8 +291,12 @@ def test_linear_kbit_fp32_bias(module):
 }
 
 
+@pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("module", module_dict.values(), ids=module_dict.keys())
-def test_kbit_backprop(module):
+def test_kbit_backprop(device, module):
+    if device == "cpu":
+        pytest.xfail("Not yet implemented on CPU")
+
     b = 16
     dim1 = 36
     dim2 = 84
@@ -297,16 +312,16 @@ def test_kbit_backprop(module):
     kbit[1].weight.detach().copy_(ref[1].weight)
     kbit[0].bias.detach().copy_(ref[0].bias)
     kbit[1].bias.detach().copy_(ref[1].bias)
-    ref = ref.half().cuda()
-    kbit = kbit.half().cuda()
-    kbit = kbit.half().to("cuda")
+    ref = ref.half().to(device)
+    kbit = kbit.half().to(device)
+    kbit = kbit.half().to(device)
 
     errs1 = []
     errs2 = []
     relerrs1 = []
     relerrs2 = []
     for i in range(100):
-        batch = torch.randn(b, dim1).half().cuda()
+        batch = torch.randn(b, dim1, device=device, dtype=torch.float16)
         out1 = ref(batch)
         out2 = kbit(batch)
         out1.mean().backward()
@@ -339,6 +354,7 @@ def test_kbit_backprop(module):
     assert kbit[0].weight.grad is None or kbit[0].bias.grad.sum().item() == 0
 
 
+@pytest.mark.deprecated
 def test_fp8linear():
     b = 10
     h = 1024
@@ -369,6 +385,7 @@ def test_fp8linear():
     assert bgraderr < 0.00002
 
 
+@pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("embedding_dim", [64, 65])
 @pytest.mark.parametrize("input_shape", [(10,), (10, 10), (10, 10, 10)], ids=str)
 @pytest.mark.parametrize(
@@ -382,7 +399,10 @@ def test_fp8linear():
     ],
     ids=lambda x: x.__name__ if inspect.isclass(x) else str(x),
 )
-def test_embedding_lossless(embedding_class, input_shape, embedding_dim, quant_storage):
+def test_embedding_lossless(device, embedding_class, input_shape, embedding_dim, quant_storage):
+    if device == "cpu":
+        pytest.xfail("Not yet supported on CPU")
+
     num_embeddings = 128
 
     src_weight = (torch.randn((num_embeddings, embedding_dim), dtype=torch.float32) > 0).to(
@@ -402,17 +422,18 @@ def test_embedding_lossless(embedding_class, input_shape, embedding_dim, quant_s
 
     e.load_state_dict(emb_base.state_dict())
 
-    emb_base.cuda()
-    e.cuda()
+    emb_base.to(device)
+    e.to(device)
 
-    input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device="cuda")
+    input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device=device)
 
     torch.testing.assert_close(
         actual=e(input_tokens),
         expected=emb_base(input_tokens),
     )
 
 
+@pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("embedding_dim", [64, 65])
 @pytest.mark.parametrize("input_shape", [(10,), (10, 10), (10, 10, 10)], ids=str)
 @pytest.mark.parametrize(
@@ -426,7 +447,10 @@ def test_embedding_lossless(embedding_class, input_shape, embedding_dim, quant_s
     ],
     ids=lambda x: x.__name__ if inspect.isclass(x) else str(x),
 )
-def test_embedding_error(embedding_class, input_shape, embedding_dim, quant_storage):
+def test_embedding_error(device, embedding_class, input_shape, embedding_dim, quant_storage):
+    if device == "cpu":
+        pytest.xfail("Not yet supported on CPU")
+
     is_8bit = embedding_class is bnb.nn.Embedding8bit
 
     num_embeddings = 128
@@ -446,10 +470,10 @@ def test_embedding_error(embedding_class, input_shape, embedding_dim, quant_stor
 
     e.load_state_dict(emb_base.state_dict())
 
-    emb_base.cuda()
-    e.cuda()
+    emb_base.to(device)
+    e.to(device)
 
-    input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device="cuda")
+    input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device=device)
 
     torch.testing.assert_close(
         actual=e(input_tokens),
@@ -459,46 +483,64 @@ def test_embedding_error(embedding_class, input_shape, embedding_dim, quant_stor
     )
 
 
-def test_4bit_linear_warnings():
+@pytest.mark.parametrize("device", get_available_devices())
+def test_4bit_linear_warnings(device):
+    if device == "cpu":
+        pytest.xfail("Not yet implemented on CPU")
+
     dim1 = 64
 
     with pytest.warns(UserWarning, match=r"inference or training"):
-        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
-        net = net.cuda()
-        inp = torch.rand(10, dim1).cuda().half()
+        net = nn.Sequential(
+            *[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
+        )
+        net = net.to(device)
+        inp = torch.rand(10, dim1, device=device, dtype=torch.float16)
         net(inp)
     with pytest.warns(UserWarning, match=r"inference."):
-        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
-        net = net.cuda()
-        inp = torch.rand(1, dim1).cuda().half()
+        net = nn.Sequential(
+            *[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
+        )
+        net = net.to(device)
+        inp = torch.rand(1, dim1, device=device, dtype=torch.float16)
         net(inp)
 
     with pytest.warns(UserWarning) as record:
-        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
-        net = net.cuda()
-        inp = torch.rand(10, dim1).cuda().half()
+        net = nn.Sequential(
+            *[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
+        )
+        net = net.to(device)
+        inp = torch.rand(10, dim1, device=device, dtype=torch.float16)
         net(inp)
 
-        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
-        net = net.cuda()
-        inp = torch.rand(1, dim1).cuda().half()
+        net = nn.Sequential(
+            *[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
+        )
+        net = net.to(device)
+        inp = torch.rand(1, dim1, device=device, dtype=torch.float16)
         net(inp)
 
     assert len(record) == 2
 
 
-def test_4bit_embedding_warnings():
+@pytest.mark.parametrize("device", get_available_devices())
+def test_4bit_embedding_warnings(device):
+    if device == "cpu":
+        pytest.xfail("Not yet implemented on CPU")
+
     num_embeddings = 128
     default_block_size = 64
 
     with pytest.warns(UserWarning, match=r"inference."):
-        net = bnb.nn.Embedding4bit(num_embeddings=num_embeddings, embedding_dim=default_block_size + 1)
-        net.cuda()
-        inp = torch.randint(low=0, high=num_embeddings, size=(1,), device="cuda")
+        net = bnb.nn.Embedding4bit(
+            num_embeddings=num_embeddings, embedding_dim=default_block_size + 1, quant_type="nf4"
+        )
+        net.to(device)
+        inp = torch.randint(low=0, high=num_embeddings, size=(1,), device=device)
         net(inp)
 
 
-def test_4bit_embedding_weight_fsdp_fix():
+def test_4bit_embedding_weight_fsdp_fix(requires_cuda):
     num_embeddings = 64
     embedding_dim = 32
@@ -515,7 +557,7 @@ def test_4bit_embedding_weight_fsdp_fix():
     assert module.weight.quant_state is not None
 
 
-def test_4bit_linear_weight_fsdp_fix():
+def test_4bit_linear_weight_fsdp_fix(requires_cuda):
     inp_size = 64
     out_size = 32
 
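Note: `requires_cuda` is a fixture provided by the suite's conftest and is not defined in this diff. A minimal sketch of the kind of fixture assumed here, which simply skips the test when no CUDA device is present:

# Hypothetical conftest-style sketch; the real fixture may differ.
import pytest
import torch

@pytest.fixture
def requires_cuda():
    if not torch.cuda.is_available():
        pytest.skip("Test requires a CUDA device")
    return True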