@@ -215,7 +215,7 @@ def test_keep_modules_in_fp32(self):
         self.assertTrue(module.weight.dtype == torch.uint8)
 
         # test if inference works.
-        with torch.no_grad() and torch.amp.autocast("cuda", dtype=torch.float16):
+        with torch.no_grad() and torch.amp.autocast(torch_device, dtype=torch.float16):
            input_dict_for_transformer = self.get_dummy_inputs()
            model_inputs = {
                k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool)
@@ -389,7 +389,7 @@ def test_training(self):
         model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})
 
         # Step 4: Check if the gradient is not None
-        with torch.amp.autocast("cuda", dtype=torch.float16):
+        with torch.amp.autocast(torch_device, dtype=torch.float16):
            out = self.model_4bit(**model_inputs)[0]
            out.norm().backward()
 
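
The change swaps the hardcoded "cuda" string for the suite's torch_device constant, so torch.amp.autocast targets whichever backend the tests actually run on. A minimal, self-contained sketch of the pattern, using a toy model and input in place of the suite's fixtures (torch_device is resolved inline here rather than imported from the test utilities):

import torch

# Resolve the device at runtime instead of hardcoding "cuda".
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
# CUDA autocast typically uses float16; CPU autocast prefers bfloat16.
amp_dtype = torch.float16 if torch_device == "cuda" else torch.bfloat16

model = torch.nn.Linear(8, 8).to(torch_device)
x = torch.randn(2, 8, device=torch_device)

# Stack context managers with a comma. `with a and b:` only enters `b`
# (the `and` expression evaluates to `b` when `a` is truthy), so the
# diff's `torch.no_grad() and torch.amp.autocast(...)` never actually
# enables no_grad.
with torch.no_grad(), torch.amp.autocast(torch_device, dtype=amp_dtype):
    out = model(x)

print(out.dtype)  # torch.float16 on CUDA, torch.bfloat16 on CPU

Note that the diff only parameterizes the device string; the `torch.no_grad() and ...` construct from the first hunk is left as-is in the commit, even though comma-stacking the two managers (as sketched above) would be needed for no_grad to take effect.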