2121
2222from diffusers import AutoencoderKL , DDIMScheduler , DiTPipeline , DiTTransformer2DModel , DPMSolverMultistepScheduler
2323from diffusers .utils import is_xformers_available
24- from diffusers .utils .testing_utils import enable_full_determinism , load_numpy , nightly , require_torch_gpu , torch_device
24+ from diffusers .utils .testing_utils import (
25+ backend_empty_cache ,
26+ enable_full_determinism ,
27+ load_numpy ,
28+ nightly ,
29+ numpy_cosine_similarity_distance ,
30+ require_torch_accelerator ,
31+ torch_device ,
32+ )
2533
2634from ..pipeline_params import (
2735 CLASS_CONDITIONED_IMAGE_GENERATION_BATCH_PARAMS ,
@@ -107,23 +115,23 @@ def test_xformers_attention_forwardGenerator_pass(self):
107115
108116
109117@nightly
110- @require_torch_gpu
118+ @require_torch_accelerator
111119class DiTPipelineIntegrationTests (unittest .TestCase ):
112120 def setUp (self ):
113121 super ().setUp ()
114122 gc .collect ()
115- torch . cuda . empty_cache ( )
123+ backend_empty_cache ( torch_device )
116124
117125 def tearDown (self ):
118126 super ().tearDown ()
119127 gc .collect ()
120- torch . cuda . empty_cache ( )
128+ backend_empty_cache ( torch_device )
121129
122130 def test_dit_256 (self ):
123131 generator = torch .manual_seed (0 )
124132
125133 pipe = DiTPipeline .from_pretrained ("facebook/DiT-XL-2-256" )
126- pipe .to ("cuda" )
134+ pipe .to (torch_device )
127135
128136 words = ["vase" , "umbrella" , "white shark" , "white wolf" ]
129137 ids = pipe .get_label_ids (words )
@@ -139,7 +147,7 @@ def test_dit_256(self):
139147 def test_dit_512 (self ):
140148 pipe = DiTPipeline .from_pretrained ("facebook/DiT-XL-2-512" )
141149 pipe .scheduler = DPMSolverMultistepScheduler .from_config (pipe .scheduler .config )
142- pipe .to ("cuda" )
150+ pipe .to (torch_device )
143151
144152 words = ["vase" , "umbrella" ]
145153 ids = pipe .get_label_ids (words )
@@ -152,4 +160,7 @@ def test_dit_512(self):
152160 f"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/dit/{ word } _512.npy"
153161 )
154162
155- assert np .abs ((expected_image - image ).max ()) < 1e-1
163+ expected_slice = expected_image .flatten ()
164+ output_slice = image .flatten ()
165+
166+ assert numpy_cosine_similarity_distance (expected_slice , output_slice ) < 1e-2
0 commit comments