
Commit bbcde6b

update tests
1 parent ad24269 commit bbcde6b

7 files changed: 164 additions & 18 deletions


src/diffusers/utils/dummy_torch_and_transformers_objects.py

Lines changed: 23 additions & 0 deletions

@@ -1277,6 +1277,21 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch", "transformers"])


+class PyramidAttentionBroadcastConfig(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
 class ReduxImageEncoder(metaclass=DummyObject):
     _backends = ["torch", "transformers"]

@@ -2535,3 +2550,11 @@ def from_config(cls, *args, **kwargs):
     @classmethod
     def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch", "transformers"])
+
+
+def apply_pyramid_attention_broadcast(*args, **kwargs):
+    requires_backends(apply_pyramid_attention_broadcast, ["torch", "transformers"])
+
+
+def apply_pyramid_attention_broadcast_on_module(*args, **kwargs):
+    requires_backends(apply_pyramid_attention_broadcast_on_module, ["torch", "transformers"])
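
Note: the block above extends diffusers' dummy-object fallbacks so that PyramidAttentionBroadcastConfig and the two apply_* helpers raise a clear "missing backend" error when torch/transformers are not installed. Below is a minimal, self-contained sketch of that pattern; the importlib-based availability check and the simplified DummyObject metaclass are illustrative stand-ins for diffusers' own utilities, not the library's actual implementation.

# Simplified sketch of the dummy-object pattern used above (illustrative, not diffusers' real code).
import importlib.util


def requires_backends(obj, backends):
    # Raise a helpful error if any required backend is not importable.
    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    missing = [b for b in backends if importlib.util.find_spec(b) is None]
    if missing:
        raise ImportError(f"{name} requires the following backends: {', '.join(missing)}")


class DummyObject(type):
    # Metaclass so that class-level access (e.g. from_config, from_pretrained) also
    # fails with an informative ImportError instead of an AttributeError.
    def __getattr__(cls, key):
        requires_backends(cls, cls._backends)


class PyramidAttentionBroadcastConfig(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])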

tests/pipelines/allegro/test_allegro.py

Lines changed: 4 additions & 4 deletions

@@ -30,13 +30,13 @@
 )

 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin, to_np
+from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np


 enable_full_determinism()


-class AllegroPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class AllegroPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase):
     pipeline_class = AllegroPipeline
     params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS

@@ -54,14 +54,14 @@ class AllegroPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     )
     test_xformers_attention = False

-    def get_dummy_components(self):
+    def get_dummy_components(self, num_layers: int = 1):
         torch.manual_seed(0)
         transformer = AllegroTransformer3DModel(
             num_attention_heads=2,
             attention_head_dim=12,
             in_channels=4,
             out_channels=4,
-            num_layers=1,
+            num_layers=num_layers,
             cross_attention_dim=24,
             sample_width=8,
             sample_height=8,

tests/pipelines/cogvideo/test_cogvideox.py

Lines changed: 4 additions & 3 deletions

@@ -32,6 +32,7 @@
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
 from ..test_pipelines_common import (
     PipelineTesterMixin,
+    PyramidAttentionBroadcastTesterMixin,
     check_qkv_fusion_matches_attn_procs_length,
     check_qkv_fusion_processors_exist,
     to_np,

@@ -41,7 +42,7 @@
 enable_full_determinism()


-class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class CogVideoXPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase):
     pipeline_class = CogVideoXPipeline
     params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS

@@ -59,7 +60,7 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     )
     test_xformers_attention = False

-    def get_dummy_components(self):
+    def get_dummy_components(self, num_layers: int = 1):
         torch.manual_seed(0)
         transformer = CogVideoXTransformer3DModel(
             # Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings

@@ -71,7 +72,7 @@ def get_dummy_components(self):
             out_channels=4,
             time_embed_dim=2,
             text_embed_dim=32,  # Must match with tiny-random-t5
-            num_layers=1,
+            num_layers=num_layers,
             sample_width=2,  # latent width: 2 -> final width: 16
             sample_height=2,  # latent height: 2 -> final height: 16
             sample_frames=9,  # latent frames: (9 - 1) / 4 + 1 = 3 -> final frames: 9

tests/pipelines/flux/test_pipeline_flux.py

Lines changed: 6 additions & 3 deletions

@@ -18,25 +18,28 @@
 from ..test_pipelines_common import (
     FluxIPAdapterTesterMixin,
     PipelineTesterMixin,
+    PyramidAttentionBroadcastTesterMixin,
     check_qkv_fusion_matches_attn_procs_length,
     check_qkv_fusion_processors_exist,
 )


-class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin):
+class FluxPipelineFastTests(
+    unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin, PyramidAttentionBroadcastTesterMixin
+):
     pipeline_class = FluxPipeline
     params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
     batch_params = frozenset(["prompt"])

     # there is no xformers processor for Flux
     test_xformers_attention = False

-    def get_dummy_components(self):
+    def get_dummy_components(self, num_layers: int = 1):
         torch.manual_seed(0)
         transformer = FluxTransformer2DModel(
             patch_size=1,
             in_channels=4,
-            num_layers=1,
+            num_layers=num_layers,
             num_single_layers=1,
             attention_head_dim=16,
             num_attention_heads=2,

tests/pipelines/hunyuan_video/test_hunyuan_video.py

Lines changed: 4 additions & 4 deletions

@@ -30,13 +30,13 @@
     torch_device,
 )

-from ..test_pipelines_common import PipelineTesterMixin, to_np
+from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np


 enable_full_determinism()


-class HunyuanVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class HunyuanVideoPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase):
     pipeline_class = HunyuanVideoPipeline
     params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
     batch_params = frozenset(["prompt"])

@@ -54,14 +54,14 @@ class HunyuanVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     # there is no xformers processor for Flux
     test_xformers_attention = False

-    def get_dummy_components(self):
+    def get_dummy_components(self, num_layers: int = 1):
         torch.manual_seed(0)
         transformer = HunyuanVideoTransformer3DModel(
             in_channels=4,
             out_channels=4,
             num_attention_heads=2,
             attention_head_dim=10,
-            num_layers=1,
+            num_layers=num_layers,
             num_single_layers=1,
             num_refiner_layers=1,
             patch_size=1,

tests/pipelines/latte/test_latte.py

Lines changed: 17 additions & 4 deletions

@@ -27,6 +27,7 @@
     DDIMScheduler,
     LattePipeline,
     LatteTransformer3DModel,
+    PyramidAttentionBroadcastConfig,
 )
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (

@@ -38,13 +39,13 @@
 )

 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin, to_np
+from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np


 enable_full_determinism()


-class LattePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class LattePipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase):
     pipeline_class = LattePipeline
     params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS

@@ -53,11 +54,23 @@ class LattePipelineFastTests(PipelineTesterMixin, unittest.TestCase):

     required_optional_params = PipelineTesterMixin.required_optional_params

-    def get_dummy_components(self):
+    pab_config = PyramidAttentionBroadcastConfig(
+        spatial_attention_block_skip_range=2,
+        temporal_attention_block_skip_range=2,
+        cross_attention_block_skip_range=2,
+        spatial_attention_timestep_skip_range=(100, 700),
+        temporal_attention_timestep_skip_range=(100, 800),
+        cross_attention_timestep_skip_range=(100, 800),
+        spatial_attention_block_identifiers=["transformer_blocks"],
+        temporal_attention_block_identifiers=["temporal_transformer_blocks"],
+        cross_attention_block_identifiers=["transformer_blocks"],
+    )
+
+    def get_dummy_components(self, num_layers: int = 1):
         torch.manual_seed(0)
         transformer = LatteTransformer3DModel(
             sample_size=8,
-            num_layers=1,
+            num_layers=num_layers,
             patch_size=2,
             attention_head_dim=8,
             num_attention_heads=3,

tests/pipelines/test_pipelines_common.py

Lines changed: 106 additions & 0 deletions

@@ -24,9 +24,11 @@
     DDIMScheduler,
     DiffusionPipeline,
     KolorsPipeline,
+    PyramidAttentionBroadcastConfig,
     StableDiffusionPipeline,
     StableDiffusionXLPipeline,
     UNet2DConditionModel,
+    apply_pyramid_attention_broadcast,
 )
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin

@@ -36,6 +38,7 @@
 from diffusers.models.unets.unet_i2vgen_xl import I2VGenXLUNet
 from diffusers.models.unets.unet_motion_model import UNetMotionModel
 from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
+from diffusers.pipelines.pyramid_attention_broadcast_utils import PyramidAttentionBroadcastHook
 from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.utils import logging
 from diffusers.utils.import_utils import is_xformers_available

@@ -2271,6 +2274,109 @@ def _test_save_load_optional_components(self, expected_max_difference=1e-4):
         self.assertLess(max_diff, expected_max_difference)


+class PyramidAttentionBroadcastTesterMixin:
+    pab_config = PyramidAttentionBroadcastConfig(
+        spatial_attention_block_skip_range=2,
+        spatial_attention_timestep_skip_range=(100, 800),
+        spatial_attention_block_identifiers=["transformer_blocks"],
+    )
+
+    def test_pyramid_attention_broadcast_layers(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        num_layers = 2
+        components = self.get_dummy_components(num_layers=num_layers)
+        pipe = self.pipeline_class(**components)
+        pipe.set_progress_bar_config(disable=None)
+
+        apply_pyramid_attention_broadcast(pipe, self.pab_config)
+
+        expected_hooks = 0
+        if self.pab_config.spatial_attention_block_skip_range is not None:
+            expected_hooks += num_layers
+        if self.pab_config.temporal_attention_block_skip_range is not None:
+            expected_hooks += num_layers
+        if self.pab_config.cross_attention_block_skip_range is not None:
+            expected_hooks += num_layers
+
+        denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
+        count = 0
+        for module in denoiser.modules():
+            if hasattr(module, "_diffusers_hook"):
+                count += 1
+                self.assertTrue(
+                    isinstance(module._diffusers_hook, PyramidAttentionBroadcastHook),
+                    "Hook should be of type PyramidAttentionBroadcastHook.",
+                )
+                self.assertTrue(
+                    hasattr(module, "_pyramid_attention_broadcast_state"),
+                    "PAB state should be initialized when enabled.",
+                )
+                self.assertTrue(
+                    module._pyramid_attention_broadcast_state.cache is None, "Cache should be None at initialization."
+                )
+        self.assertEqual(count, expected_hooks, "Number of hooks should match the expected number.")
+
+        # Perform dummy inference step to ensure state is updated
+        def pab_state_check_callback(pipe, i, t, kwargs):
+            for module in denoiser.modules():
+                if hasattr(module, "_diffusers_hook"):
+                    self.assertTrue(
+                        module._pyramid_attention_broadcast_state.cache is not None,
+                        "Cache should have updated during inference.",
+                    )
+                    self.assertTrue(
+                        module._pyramid_attention_broadcast_state.iteration == i + 1,
+                        "Hook iteration state should have updated during inference.",
+                    )
+            return {}
+
+        inputs = self.get_dummy_inputs(device)
+        inputs["num_inference_steps"] = 2
+        inputs["callback_on_step_end"] = pab_state_check_callback
+        pipe(**inputs)[0]
+
+        # After inference, reset_stateful_hooks is called within the pipeline, which should have reset the states
+        for module in denoiser.modules():
+            if hasattr(module, "_diffusers_hook"):
+                self.assertTrue(
+                    module._pyramid_attention_broadcast_state.cache is None,
+                    "Cache should be reset to None after inference.",
+                )
+                self.assertTrue(
+                    module._pyramid_attention_broadcast_state.iteration == 0,
+                    "Iteration should be reset to 0 after inference.",
+                )
+
+    def test_pyramid_attention_broadcast_inference(self, expected_atol: float = 0.2):
+        # We need to use higher tolerance because we are using a random model. With a converged/trained
+        # model, the tolerance can be lower.
+
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        num_layers = 2
+        components = self.get_dummy_components(num_layers=num_layers)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        inputs["num_inference_steps"] = 4
+        output = pipe(**inputs)[0]
+        original_image_slice = output.flatten()
+        original_image_slice = np.concatenate((original_image_slice[:8], original_image_slice[-8:]))
+
+        apply_pyramid_attention_broadcast(pipe, self.pab_config)
+
+        inputs = self.get_dummy_inputs(device)
+        inputs["num_inference_steps"] = 4
+        output = pipe(**inputs)[0]
+        image_slice_pab_enabled = output.flatten()
+        image_slice_pab_enabled = np.concatenate((image_slice_pab_enabled[:8], image_slice_pab_enabled[-8:]))
+
+        assert np.allclose(
+            original_image_slice, image_slice_pab_enabled, atol=expected_atol
+        ), "PAB outputs should not differ much in specified timestep range."
+
+
 # Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
 # This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
 # reference image.
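
Note: the pipeline-test changes above all follow the same opt-in pattern: inherit PyramidAttentionBroadcastTesterMixin and let get_dummy_components accept num_layers so the mixin can build a two-layer dummy denoiser. For orientation, a rough sketch of the user-facing flow these tests exercise is below, assuming the API introduced in this PR (PyramidAttentionBroadcastConfig and apply_pyramid_attention_broadcast exported from diffusers); the model id, prompt, and generation settings are illustrative only.

# Illustrative only: applies PAB to a real CogVideoX pipeline the same way the
# tests apply it to the dummy pipelines. Assumes this PR's top-level exports.
import torch

from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
from diffusers.utils import export_to_video

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")

# Same knobs as the mixin's default pab_config: reuse every other spatial attention
# block's output, but only for timesteps within (100, 800).
config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(100, 800),
    spatial_attention_block_identifiers=["transformer_blocks"],
)
apply_pyramid_attention_broadcast(pipe, config)

video = pipe("A panda strumming a guitar", num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=8)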
