Commit 9d452dc: update tests

Parent: 0ea904e

File tree

3 files changed: 21 additions, 69 deletions

tests/pipelines/cogvideo/test_cogvideox.py

Lines changed: 7 additions & 23 deletions
@@ -21,7 +21,10 @@
 from transformers import AutoTokenizer, T5EncoderModel
 
 from diffusers import AutoencoderKLCogVideoX, CogVideoXPipeline, CogVideoXTransformer3DModel, DDIMScheduler
-from diffusers.pipelines.pyramid_attention_broadcast_utils import PyramidAttentionBroadcastAttentionProcessorWrapper
+from diffusers.pipelines.pyramid_attention_broadcast_utils import (
+    PyramidAttentionBroadcastConfig,
+    apply_pyramid_attention_broadcast,
+)
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
     numpy_cosine_similarity_distance,
@@ -333,40 +336,21 @@ def test_pyramid_attention_broadcast(self):
         frames = pipe(**inputs).frames  # [B, F, C, H, W]
         original_image_slice = frames[0, -2:, -1, -3:, -3:]
 
-        pipe.enable_pyramid_attention_broadcast(spatial_attn_skip_range=2, spatial_attn_timestep_range=(100, 800))
-        assert pipe.pyramid_attention_broadcast_enabled
-
-        num_pab_processors = sum(
-            [
-                isinstance(processor, PyramidAttentionBroadcastAttentionProcessorWrapper)
-                for processor in pipe.transformer.attn_processors.values()
-            ]
+        config = PyramidAttentionBroadcastConfig(
+            spatial_attention_block_skip_range=2, spatial_attention_timestep_skip_range=(100, 800)
         )
-        assert num_pab_processors == num_layers
+        apply_pyramid_attention_broadcast(pipe, config)
 
         inputs = self.get_dummy_inputs(device)
         inputs["num_inference_steps"] = 4
         frames = pipe(**inputs).frames
         image_slice_pab_enabled = frames[0, -2:, -1, -3:, -3:]
 
-        pipe.disable_pyramid_attention_broadcast()
-        assert not pipe.pyramid_attention_broadcast_enabled
-
-        inputs = self.get_dummy_inputs(device)
-        frames = pipe(**inputs).frames
-        image_slice_pab_disabled = frames[0, -2:, -1, -3:, -3:]
-
         # We need to use higher tolerance because we are using a random model. With a converged/trained
         # model, the tolerance can be lower.
         assert np.allclose(
             original_image_slice, image_slice_pab_enabled, atol=0.2
         ), "PAB outputs should not differ much in specified timestep range."
-        assert np.allclose(
-            image_slice_pab_enabled, image_slice_pab_disabled, atol=0.2
-        ), "Outputs, with PAB enabled, shouldn't differ much when PAB is disabled in specified timestep range."
-        assert np.allclose(
-            original_image_slice, image_slice_pab_disabled, atol=0.2
-        ), "Original outputs should match when PAB is disabled."
 
 
 @slow
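
For reference, the config-plus-helper API these tests migrate to can be used roughly as follows. This is a minimal sketch, not part of the commit: the checkpoint name, dtype, device, and prompt are illustrative assumptions, and only PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast, and the two spatial-attention parameters are taken from the diff above.

import torch

from diffusers import CogVideoXPipeline
from diffusers.pipelines.pyramid_attention_broadcast_utils import (
    PyramidAttentionBroadcastConfig,
    apply_pyramid_attention_broadcast,
)

# Illustrative checkpoint; the tests above use a small random model instead.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16)
pipe.to("cuda")

# Recompute spatial attention only every 2nd invocation and broadcast the
# cached output otherwise, restricted to denoising timesteps in (100, 800).
config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(100, 800),
)

# Replaces the old pipe.enable_pyramid_attention_broadcast(...) method:
# the helper takes the pipeline plus a config object.
apply_pyramid_attention_broadcast(pipe, config)

frames = pipe("a panda playing a guitar", num_inference_steps=50).frames

The same two-step pattern (build a config, then apply it to the pipeline) is exercised identically in all three test files in this commit.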

tests/pipelines/cogvideo/test_cogvideox_image2video.py

Lines changed: 7 additions & 23 deletions
@@ -22,7 +22,10 @@
 from transformers import AutoTokenizer, T5EncoderModel
 
 from diffusers import AutoencoderKLCogVideoX, CogVideoXImageToVideoPipeline, CogVideoXTransformer3DModel, DDIMScheduler
-from diffusers.pipelines.pyramid_attention_broadcast_utils import PyramidAttentionBroadcastAttentionProcessorWrapper
+from diffusers.pipelines.pyramid_attention_broadcast_utils import (
+    PyramidAttentionBroadcastConfig,
+    apply_pyramid_attention_broadcast,
+)
 from diffusers.utils import load_image
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
@@ -356,40 +359,21 @@ def test_pyramid_attention_broadcast(self):
         frames = pipe(**inputs).frames  # [B, F, C, H, W]
         original_image_slice = frames[0, -2:, -1, -3:, -3:]
 
-        pipe.enable_pyramid_attention_broadcast(spatial_attn_skip_range=2, spatial_attn_timestep_range=(100, 800))
-        assert pipe.pyramid_attention_broadcast_enabled
-
-        num_pab_processors = sum(
-            [
-                isinstance(processor, PyramidAttentionBroadcastAttentionProcessorWrapper)
-                for processor in pipe.transformer.attn_processors.values()
-            ]
+        config = PyramidAttentionBroadcastConfig(
+            spatial_attention_block_skip_range=2, spatial_attention_timestep_skip_range=(100, 800)
         )
-        assert num_pab_processors == num_layers
+        apply_pyramid_attention_broadcast(pipe, config)
 
         inputs = self.get_dummy_inputs(device)
         inputs["num_inference_steps"] = 4
         frames = pipe(**inputs).frames
         image_slice_pab_enabled = frames[0, -2:, -1, -3:, -3:]
 
-        pipe.disable_pyramid_attention_broadcast()
-        assert not pipe.pyramid_attention_broadcast_enabled
-
-        inputs = self.get_dummy_inputs(device)
-        frames = pipe(**inputs).frames
-        image_slice_pab_disabled = frames[0, -2:, -1, -3:, -3:]
-
         # We need to use higher tolerance because we are using a random model. With a converged/trained
         # model, the tolerance can be lower.
         assert np.allclose(
             original_image_slice, image_slice_pab_enabled, atol=0.2
         ), "PAB outputs should not differ much in specified timestep range."
-        assert np.allclose(
-            image_slice_pab_enabled, image_slice_pab_disabled, atol=0.2
-        ), "Outputs, with PAB enabled, shouldn't differ much when PAB is disabled in specified timestep range."
-        assert np.allclose(
-            original_image_slice, image_slice_pab_disabled, atol=0.2
-        ), "Original outputs should match when PAB is disabled."
 
 
 @slow

tests/pipelines/cogvideo/test_cogvideox_video2video.py

Lines changed: 7 additions & 23 deletions
@@ -21,7 +21,10 @@
 from transformers import AutoTokenizer, T5EncoderModel
 
 from diffusers import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel, CogVideoXVideoToVideoPipeline, DDIMScheduler
-from diffusers.pipelines.pyramid_attention_broadcast_utils import PyramidAttentionBroadcastAttentionProcessorWrapper
+from diffusers.pipelines.pyramid_attention_broadcast_utils import (
+    PyramidAttentionBroadcastConfig,
+    apply_pyramid_attention_broadcast,
+)
 from diffusers.utils.testing_utils import enable_full_determinism, torch_device
 
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
@@ -338,37 +341,18 @@ def test_pyramid_attention_broadcast(self):
         frames = pipe(**inputs).frames  # [B, F, C, H, W]
         original_image_slice = frames[0, -2:, -1, -3:, -3:]
 
-        pipe.enable_pyramid_attention_broadcast(spatial_attn_skip_range=2, spatial_attn_timestep_range=(100, 800))
-        assert pipe.pyramid_attention_broadcast_enabled
-
-        num_pab_processors = sum(
-            [
-                isinstance(processor, PyramidAttentionBroadcastAttentionProcessorWrapper)
-                for processor in pipe.transformer.attn_processors.values()
-            ]
+        config = PyramidAttentionBroadcastConfig(
+            spatial_attention_block_skip_range=2, spatial_attention_timestep_skip_range=(100, 800)
         )
-        assert num_pab_processors == num_layers
+        apply_pyramid_attention_broadcast(pipe, config)
 
         inputs = self.get_dummy_inputs(device)
         inputs["num_inference_steps"] = 4
         frames = pipe(**inputs).frames
         image_slice_pab_enabled = frames[0, -2:, -1, -3:, -3:]
 
-        pipe.disable_pyramid_attention_broadcast()
-        assert not pipe.pyramid_attention_broadcast_enabled
-
-        inputs = self.get_dummy_inputs(device)
-        frames = pipe(**inputs).frames
-        image_slice_pab_disabled = frames[0, -2:, -1, -3:, -3:]
-
         # We need to use higher tolerance because we are using a random model. With a converged/trained
         # model, the tolerance can be lower.
         assert np.allclose(
             original_image_slice, image_slice_pab_enabled, atol=0.2
         ), "PAB outputs should not differ much in specified timestep range."
-        assert np.allclose(
-            image_slice_pab_enabled, image_slice_pab_disabled, atol=0.2
-        ), "Outputs, with PAB enabled, shouldn't differ much when PAB is disabled in specified timestep range."
-        assert np.allclose(
-            original_image_slice, image_slice_pab_disabled, atol=0.2
-        ), "Original outputs should match when PAB is disabled."
