@@ -284,7 +284,8 @@ def test_linear_kbit_fp32_bias(device, module):
 
 @pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("module", module_dict.values(), ids=module_dict.keys())
-def test_kbit_backprop(device, module):
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+def test_kbit_backprop(device, module, dtype):
     b = 16
     dim1 = 36
     dim2 = 84
@@ -298,24 +299,28 @@ def test_kbit_backprop(device, module):
 
     kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 128)])
 
-    if device == "hpu" and isinstance(kbit[1], bnb.nn.Linear4bit) and kbit[1].weight.quant_type == "fp4":
-        pytest.skip("FP4 is not supported on HPU")
+    if (
+        device == "hpu"
+        and isinstance(kbit[1], bnb.nn.Linear4bit)
+        and not is_supported_on_hpu(kbit[1].weight.quant_type, dtype)
+    ):
+        pytest.skip("This configuration not supported on HPU")
 
     kbit[0].weight.detach().copy_(ref[0].weight)
     kbit[1].weight.detach().copy_(ref[1].weight)
     kbit[0].bias.detach().copy_(ref[0].bias)
     kbit[1].bias.detach().copy_(ref[1].bias)
     kbit[1].weight.requires_grad_(False)
-    ref = ref.half().to(device)
-    kbit = kbit.half().to(device)
-    kbit = kbit.half().to(device)
+    ref = ref.to(device=device, dtype=dtype)
+    kbit = kbit.to(device=device, dtype=dtype)
+    kbit = kbit.to(device=device, dtype=dtype)
 
     errs1 = []
     errs2 = []
     relerrs1 = []
     relerrs2 = []
     for i in range(100):
-        batch = torch.randn(b, dim1, device=device, dtype=torch.float16)
+        batch = torch.randn(b, dim1, device=device, dtype=dtype)
         out1 = ref(batch)
         out2 = kbit(batch)
         out1.mean().backward()
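
The rewritten skip condition calls an `is_supported_on_hpu` helper whose implementation is not part of this diff. A minimal sketch of what such a helper might look like, assuming it only needs to encode the old FP4-on-HPU restriction plus the compute dtypes the test now parametrizes, is:

```python
# Hypothetical sketch; the helper imported by the test is not shown in this
# diff and its real implementation may differ.
import torch


def is_supported_on_hpu(quant_type: str, dtype: torch.dtype) -> bool:
    """Return True if a (4-bit quant type, compute dtype) pair is expected to work on HPU."""
    # The old code unconditionally skipped FP4 on HPU, so keep that rule.
    if quant_type == "fp4":
        return False
    # Assumption: only the dtypes parametrized by the test are exercised.
    return dtype in (torch.float16, torch.bfloat16)
```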