77from diffusers import EulerDiscreteScheduler , StableDiffusionPipeline
88from diffusers .loaders .single_file_utils import _extract_repo_id_and_weights_name
99from diffusers .utils .testing_utils import (
10+ backend_empty_cache ,
1011 enable_full_determinism ,
11- require_torch_gpu ,
12+ require_torch_accelerator ,
1213 slow ,
14+ torch_device ,
1315)
1416
1517from .single_file_testing_utils import (
2325
2426
2527@slow
26- @require_torch_gpu
28+ @require_torch_accelerator
2729class StableDiffusionPipelineSingleFileSlowTests (unittest .TestCase , SDSingleFileTesterMixin ):
2830 pipeline_class = StableDiffusionPipeline
2931 ckpt_path = "https://huggingface.co/Jiali/stable-diffusion-1.5/blob/main/v1-5-pruned-emaonly.safetensors"
@@ -35,12 +37,12 @@ class StableDiffusionPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFile
3537 def setUp (self ):
3638 super ().setUp ()
3739 gc .collect ()
38- torch . cuda . empty_cache ( )
40+ backend_empty_cache ( torch_device )
3941
4042 def tearDown (self ):
4143 super ().tearDown ()
4244 gc .collect ()
43- torch . cuda . empty_cache ( )
45+ backend_empty_cache ( torch_device )
4446
4547 def get_inputs (self , device , generator_device = "cpu" , dtype = torch .float32 , seed = 0 ):
4648 generator = torch .Generator (device = generator_device ).manual_seed (seed )
@@ -93,12 +95,12 @@ class StableDiffusion21PipelineSingleFileSlowTests(unittest.TestCase, SDSingleFi
9395 def setUp (self ):
9496 super ().setUp ()
9597 gc .collect ()
96- torch . cuda . empty_cache ( )
98+ backend_empty_cache ( torch_device )
9799
98100 def tearDown (self ):
99101 super ().tearDown ()
100102 gc .collect ()
101- torch . cuda . empty_cache ( )
103+ backend_empty_cache ( torch_device )
102104
103105 def get_inputs (self , device , generator_device = "cpu" , dtype = torch .float32 , seed = 0 ):
104106 generator = torch .Generator (device = generator_device ).manual_seed (seed )
0 commit comments