Skip to content

Commit a1b3331

Browse files
HPU test update
1 parent 214c3f3 commit a1b3331

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

tests/test_modules.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -298,9 +298,8 @@ def test_kbit_backprop(device, module):
298298

299299
kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 128)])
300300

301-
if device == "hpu":
302-
if isinstance(kbit, bnb.nn.LinearFP4):
303-
pytest.skip("FP4 is not supported on HPU")
301+
if device == "hpu" and isinstance(kbit[1], bnb.nn.LinearFP4):
302+
pytest.skip("FP4 is not supported on HPU")
304303

305304
kbit[0].weight.detach().copy_(ref[0].weight)
306305
kbit[1].weight.detach().copy_(ref[1].weight)

0 commit comments

Comments
 (0)