|
22 | 22 | import torch |
23 | 23 | from transformers import AutoTokenizer, T5EncoderModel |
24 | 24 |
|
25 | | -from diffusers import ( |
26 | | - AutoencoderKL, |
27 | | - DDIMScheduler, |
28 | | - LattePipeline, |
29 | | - LatteTransformer3DModel, |
| 25 | +from diffusers import AutoencoderKL, DDIMScheduler, LattePipeline, LatteTransformer3DModel |
| 26 | +from diffusers.pipelines.pyramid_attention_broadcast_utils import ( |
| 27 | + PyramidAttentionBroadcastConfig, |
| 28 | + apply_pyramid_attention_broadcast, |
30 | 29 | ) |
31 | 30 | from diffusers.utils.import_utils import is_xformers_available |
32 | 31 | from diffusers.utils.testing_utils import ( |
@@ -277,33 +276,24 @@ def test_pyramid_attention_broadcast(self): |
277 | 276 | frames = pipe(**inputs).frames # [B, F, C, H, W] |
278 | 277 | original_image_slice = frames[0, -2:, -1, -3:, -3:] |
279 | 278 |
|
280 | | - pipe.enable_pyramid_attention_broadcast(spatial_attn_skip_range=2, spatial_attn_timestep_range=(100, 800)) |
281 | | - assert pipe.pyramid_attention_broadcast_enabled |
| 279 | + config = PyramidAttentionBroadcastConfig( |
| 280 | + spatial_attention_block_skip_range=2, |
| 281 | + temporal_attention_block_skip_range=3, |
| 282 | + spatial_attention_timestep_skip_range=(100, 800), |
| 283 | + temporal_attention_timestep_skip_range=(100, 800), |
| 284 | + ) |
| 285 | + apply_pyramid_attention_broadcast(pipe, config) |
282 | 286 |
|
283 | 287 | inputs = self.get_dummy_inputs(device) |
284 | 288 | inputs["num_inference_steps"] = 4 |
285 | 289 | frames = pipe(**inputs).frames |
286 | 290 | image_slice_pab_enabled = frames[0, -2:, -1, -3:, -3:] |
287 | 291 |
|
288 | | - pipe.disable_pyramid_attention_broadcast() |
289 | | - assert not pipe.pyramid_attention_broadcast_enabled |
290 | | - |
291 | | - inputs = self.get_dummy_inputs(device) |
292 | | - frames = pipe(**inputs).frames |
293 | | - image_slice_pab_disabled = frames[0, -2:, -1, -3:, -3:] |
294 | | - |
295 | 292 | # We need to use higher tolerance because we are using a random model. With a converged/trained |
296 | 293 | # model, the tolerance can be lower. |
297 | 294 | assert np.allclose( |
298 | | - original_image_slice, image_slice_pab_enabled, atol=0.25 |
| 295 | + original_image_slice, image_slice_pab_enabled, atol=0.2 |
299 | 296 | ), "PAB outputs should not differ much in specified timestep range." |
300 | | - print((image_slice_pab_disabled - image_slice_pab_enabled).abs().max()) |
301 | | - assert np.allclose( |
302 | | - image_slice_pab_enabled, image_slice_pab_disabled, atol=0.25 |
303 | | - ), "Outputs, with PAB enabled, shouldn't differ much when PAB is disabled in specified timestep range." |
304 | | - assert np.allclose( |
305 | | - original_image_slice, image_slice_pab_disabled, atol=0.25 |
306 | | - ), "Original outputs should match when PAB is disabled." |
307 | 297 |
|
308 | 298 |
|
309 | 299 | @slow |
|
0 commit comments