Skip to content

Commit 0a7f959

Browse files
HPU test update
1 parent a1b3331 commit 0a7f959

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tests/test_modules.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -276,9 +276,9 @@ def test_linear_kbit_fp32_bias(device, module):
276276
"NF4": bnb.nn.LinearNF4,
277277
"FP4+C": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True),
278278
"NF4+C": lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compress_statistics=True),
279-
"NF4+fp32": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float32),
280-
"NF4+fp16": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float16),
281-
"NF4+bf16": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.bfloat16),
279+
"NF4+fp32": lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compute_dtype=torch.float32),
280+
"NF4+fp16": lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compute_dtype=torch.float16),
281+
"NF4+bf16": lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compute_dtype=torch.bfloat16),
282282
}
283283

284284

0 commit comments

Comments
 (0)