77from  diffusers  import  EulerDiscreteScheduler , StableDiffusionPipeline 
88from  diffusers .loaders .single_file_utils  import  _extract_repo_id_and_weights_name 
99from  diffusers .utils .testing_utils  import  (
10+     backend_empty_cache ,
1011    enable_full_determinism ,
11-     require_torch_gpu ,
12+     require_torch_accelerator ,
1213    slow ,
14+     torch_device ,
1315)
1416
1517from  .single_file_testing_utils  import  (
2325
2426
2527@slow  
26- @require_torch_gpu  
28+ @require_torch_accelerator  
2729class  StableDiffusionPipelineSingleFileSlowTests (unittest .TestCase , SDSingleFileTesterMixin ):
2830    pipeline_class  =  StableDiffusionPipeline 
2931    ckpt_path  =  "https://huggingface.co/Jiali/stable-diffusion-1.5/blob/main/v1-5-pruned-emaonly.safetensors" 
@@ -35,12 +37,12 @@ class StableDiffusionPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFile
3537    def  setUp (self ):
3638        super ().setUp ()
3739        gc .collect ()
38-         torch . cuda . empty_cache ( )
40+         backend_empty_cache ( torch_device )
3941
4042    def  tearDown (self ):
4143        super ().tearDown ()
4244        gc .collect ()
43-         torch . cuda . empty_cache ( )
45+         backend_empty_cache ( torch_device )
4446
4547    def  get_inputs (self , device , generator_device = "cpu" , dtype = torch .float32 , seed = 0 ):
4648        generator  =  torch .Generator (device = generator_device ).manual_seed (seed )
@@ -93,12 +95,12 @@ class StableDiffusion21PipelineSingleFileSlowTests(unittest.TestCase, SDSingleFi
9395    def  setUp (self ):
9496        super ().setUp ()
9597        gc .collect ()
96-         torch . cuda . empty_cache ( )
98+         backend_empty_cache ( torch_device )
9799
98100    def  tearDown (self ):
99101        super ().tearDown ()
100102        gc .collect ()
101-         torch . cuda . empty_cache ( )
103+         backend_empty_cache ( torch_device )
102104
103105    def  get_inputs (self , device , generator_device = "cpu" , dtype = torch .float32 , seed = 0 ):
104106        generator  =  torch .Generator (device = generator_device ).manual_seed (seed )
0 commit comments