@@ -1828,8 +1828,8 @@ def test_wrong_device_map_raises_error(self, device_map, msg_substring):
18281828
18291829 assert msg_substring in str (err_ctx .exception )
18301830
1831- @parameterized .expand ([0 , "cuda" , torch .device ("cuda" )])
1832- @require_torch_gpu
1831+ @parameterized .expand ([0 , torch_device , torch .device (torch_device )])
1832+ @require_torch_accelerator
18331833 def test_passing_non_dict_device_map_works (self , device_map ):
18341834 init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
18351835 model = self .model_class (** init_dict ).eval ()
@@ -1838,8 +1838,8 @@ def test_passing_non_dict_device_map_works(self, device_map):
18381838 loaded_model = self .model_class .from_pretrained (tmpdir , device_map = device_map )
18391839 _ = loaded_model (** inputs_dict )
18401840
1841- @parameterized .expand ([("" , "cuda" ), ("" , torch .device ("cuda" ))])
1842- @require_torch_gpu
1841+ @parameterized .expand ([("" , torch_device ), ("" , torch .device (torch_device ))])
1842+ @require_torch_accelerator
18431843 def test_passing_dict_device_map_works (self , name , device ):
18441844 # There are other valid dict-based `device_map` values too. It's best to refer to
18451845 # the docs for those: https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap.
0 commit comments