Skip to content

Commit 31034b4

Browse files
authored
Update unit tests for HPU (#1682)
1 parent 29564ad commit 31034b4

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

tests/test_modules.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,8 @@ def test_linear_kbit_fp32_bias(device, module):
284284

285285
@pytest.mark.parametrize("device", get_available_devices())
286286
@pytest.mark.parametrize("module", module_dict.values(), ids=module_dict.keys())
287-
def test_kbit_backprop(device, module):
287+
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
288+
def test_kbit_backprop(device, module, dtype):
288289
b = 16
289290
dim1 = 36
290291
dim2 = 84
@@ -298,24 +299,28 @@ def test_kbit_backprop(device, module):
298299

299300
kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 128)])
300301

301-
if device == "hpu" and isinstance(kbit[1], bnb.nn.Linear4bit) and kbit[1].weight.quant_type == "fp4":
302-
pytest.skip("FP4 is not supported on HPU")
302+
if (
303+
device == "hpu"
304+
and isinstance(kbit[1], bnb.nn.Linear4bit)
305+
and not is_supported_on_hpu(kbit[1].weight.quant_type, dtype)
306+
):
307+
pytest.skip("This configuration not supported on HPU")
303308

304309
kbit[0].weight.detach().copy_(ref[0].weight)
305310
kbit[1].weight.detach().copy_(ref[1].weight)
306311
kbit[0].bias.detach().copy_(ref[0].bias)
307312
kbit[1].bias.detach().copy_(ref[1].bias)
308313
kbit[1].weight.requires_grad_(False)
309-
ref = ref.half().to(device)
310-
kbit = kbit.half().to(device)
311-
kbit = kbit.half().to(device)
314+
ref = ref.to(device=device, dtype=dtype)
315+
kbit = kbit.to(device=device, dtype=dtype)
316+
kbit = kbit.to(device=device, dtype=dtype)
312317

313318
errs1 = []
314319
errs2 = []
315320
relerrs1 = []
316321
relerrs2 = []
317322
for i in range(100):
318-
batch = torch.randn(b, dim1, device=device, dtype=torch.float16)
323+
batch = torch.randn(b, dim1, device=device, dtype=dtype)
319324
out1 = ref(batch)
320325
out2 = kbit(batch)
321326
out1.mean().backward()

0 commit comments

Comments
 (0)