Skip to content

Commit 68a6211

Browse files
SunMarcsayakpaul
andauthored
Apply suggestions from code review
Co-authored-by: Sayak Paul <[email protected]>
1 parent 69dda9a commit 68a6211

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tests/models/test_modeling_common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ def test_keep_modules_in_fp32(self):
346346

347347
model = SD3Transformer2DModel.from_pretrained(
348348
"hf-internal-testing/tiny-sd3-pipe", subfolder="transformer", torch_dtype=torch_dtype
349-
).to("cuda")
349+
).to(torch_device)
350350

351351
for name, module in model.named_modules():
352352
if isinstance(module, torch.nn.Linear):
@@ -375,7 +375,7 @@ def get_dummy_inputs():
375375
}
376376

377377
# test if inference works.
378-
with torch.no_grad() and torch.amp.autocast("cuda", dtype=torch_dtype):
378+
with torch.no_grad() and torch.amp.autocast(torch_device, dtype=torch_dtype):
379379
input_dict_for_transformer = get_dummy_inputs()
380380
model_inputs = {
381381
k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool)

0 commit comments

Comments
 (0)