7575    require_torch_2 ,
7676    require_torch_accelerator ,
7777    require_torch_accelerator_with_training ,
78-     require_torch_gpu ,
7978    require_torch_multi_accelerator ,
8079    require_torch_version_greater ,
8180    run_test_in_subprocess ,
@@ -1829,8 +1828,8 @@ def test_wrong_device_map_raises_error(self, device_map, msg_substring):
18291828
18301829        assert  msg_substring  in  str (err_ctx .exception )
18311830
1832-     @parameterized .expand ([0 , "cuda" , torch .device ("cuda" )]) 
1833-     @require_torch_gpu  
1831+     @parameterized .expand ([0 , torch_device , torch .device (torch_device )]) 
1832+     @require_torch_accelerator  
18341833    def  test_passing_non_dict_device_map_works (self , device_map ):
18351834        init_dict , inputs_dict  =  self .prepare_init_args_and_inputs_for_common ()
18361835        model  =  self .model_class (** init_dict ).eval ()
@@ -1839,8 +1838,8 @@ def test_passing_non_dict_device_map_works(self, device_map):
18391838            loaded_model  =  self .model_class .from_pretrained (tmpdir , device_map = device_map )
18401839            _  =  loaded_model (** inputs_dict )
18411840
1842-     @parameterized .expand ([("" , "cuda" ), ("" , torch .device ("cuda" ))]) 
1843-     @require_torch_gpu  
1841+     @parameterized .expand ([("" , torch_device ), ("" , torch .device (torch_device ))]) 
1842+     @require_torch_accelerator  
18441843    def  test_passing_dict_device_map_works (self , name , device ):
18451844        # There are other valid dict-based `device_map` values too. It's best to refer to 
18461845        # the docs for those: https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap. 
@@ -1945,7 +1944,7 @@ def test_push_to_hub_library_name(self):
19451944        delete_repo (self .repo_id , token = TOKEN )
19461945
19471946
1948- @require_torch_gpu  
1947+ @require_torch_accelerator  
19491948@require_torch_2  
19501949@is_torch_compile  
19511950@slow  
@@ -2013,7 +2012,7 @@ def test_compile_with_group_offloading(self):
20132012        model .eval ()
20142013        # TODO: Can test for other group offloading kwargs later if needed. 
20152014        group_offload_kwargs  =  {
2016-             "onload_device" : "cuda" ,
2015+             "onload_device" : torch_device ,
20172016            "offload_device" : "cpu" ,
20182017            "offload_type" : "block_level" ,
20192018            "num_blocks_per_group" : 1 ,
0 commit comments