Skip to content

Commit 3e13347

Browse files
committed
add 2 more cases
1 parent dada34c commit 3e13347

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tests/quantization/bnb/test_4bit.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -215,7 +215,7 @@ def test_keep_modules_in_fp32(self):
215215
self.assertTrue(module.weight.dtype == torch.uint8)
216216

217217
# test if inference works.
218-
with torch.no_grad() and torch.amp.autocast("cuda", dtype=torch.float16):
218+
with torch.no_grad() and torch.amp.autocast(torch_device, dtype=torch.float16):
219219
input_dict_for_transformer = self.get_dummy_inputs()
220220
model_inputs = {
221221
k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool)
@@ -389,7 +389,7 @@ def test_training(self):
389389
model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})
390390

391391
# Step 4: Check if the gradient is not None
392-
with torch.amp.autocast("cuda", dtype=torch.float16):
392+
with torch.amp.autocast(torch_device, dtype=torch.float16):
393393
out = self.model_4bit(**model_inputs)[0]
394394
out.norm().backward()
395395

0 commit comments

Comments (0)