Skip to content

Commit 2ba4b8f

Browse files
HPU test update
1 parent 0a7f959 commit 2ba4b8f

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tests/test_modules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ def test_kbit_backprop(device, module):
298298

299299
kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 128)])
300300

301-
if device == "hpu" and isinstance(kbit[1], bnb.nn.LinearFP4):
301+
if device == "hpu" and isinstance(kbit[1], bnb.nn.Linear4bit) and kbit[1].weight.quant_type == "fp4":
302302
pytest.skip("FP4 is not supported on HPU")
303303

304304
kbit[0].weight.detach().copy_(ref[0].weight)

0 commit comments

Comments
 (0)