File tree Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -346,7 +346,7 @@ def test_keep_modules_in_fp32(self):
346346
347347 model = SD3Transformer2DModel .from_pretrained (
348348 "hf-internal-testing/tiny-sd3-pipe" , subfolder = "transformer" , torch_dtype = torch_dtype
349- ).to ("cuda" )
349+ ).to (torch_device )
350350
351351 for name , module in model .named_modules ():
352352 if isinstance (module , torch .nn .Linear ):
@@ -375,7 +375,7 @@ def get_dummy_inputs():
375375 }
376376
377377 # test if inference works.
378- with torch .no_grad () and torch .amp .autocast ("cuda" , dtype = torch_dtype ):
378+ with torch .no_grad () and torch .amp .autocast (torch_device , dtype = torch_dtype ):
379379 input_dict_for_transformer = get_dummy_inputs ()
380380 model_inputs = {
381381 k : v .to (device = torch_device ) for k , v in input_dict_for_transformer .items () if not isinstance (v , bool )
You can’t perform that action at this time.
0 commit comments