
Commit f92f45e

committed: update tests
1 parent ea18eb6 commit f92f45e

6 files changed: +110 −83 lines changed

src/diffusers/hooks/faster_cache.py

Lines changed: 13 additions & 17 deletions
@@ -31,11 +31,12 @@
 _FASTER_CACHE_BLOCK_HOOK = "faster_cache_block"
 _ATTENTION_CLASSES = (Attention, MochiAttention)
 _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
-    "blocks.*attn",
-    "transformer_blocks.*attn",
-    "single_transformer_blocks.*attn",
+    "^blocks.*attn",
+    "^transformer_blocks.*attn",
+    "^single_transformer_blocks.*attn"
 )
-_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks.*attn",)
+_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("^temporal_transformer_blocks.*attn",)
+_TRANSFORMER_BLOCK_IDENTIFIERS = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
 _UNCOND_COND_INPUT_KWARGS_IDENTIFIERS = (
     "hidden_states",
     "encoder_hidden_states",
@@ -276,9 +277,10 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
             self.state.iteration > 0
             and is_within_timestep_range
             and self.state.iteration % self.unconditional_batch_skip_range != 0
+            and not self.is_guidance_distilled
         )
 
-        if should_skip_uncond and not self.is_guidance_distilled:
+        if should_skip_uncond:
             is_any_kwarg_uncond = any(k in self.uncond_cond_input_kwargs_identifiers for k in kwargs.keys())
             if is_any_kwarg_uncond:
                 logger.debug("FasterCache - Skipping unconditional branch computation")
@@ -483,7 +485,7 @@ def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
 
 def apply_faster_cache(
     module: torch.nn.Module,
-    config: Optional[FasterCacheConfig] = None,
+    config: FasterCacheConfig
 ) -> None:
     r"""
     Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline.
@@ -515,10 +517,6 @@ def apply_faster_cache(
     ```
     """
 
-    if config is None:
-        logger.warning("No FasterCacheConfig provided. Using default configuration.")
-        config = FasterCacheConfig()
-
     if config.attention_weight_callback is None:
         # If the user has not provided a weight callback, we default to 0.5 for all timesteps.
         # In the paper, they recommend using a gradually increasing weight from 0 to 1 as the inference progresses, but
@@ -568,7 +566,8 @@ def high_frequency_weight_callback(module: torch.nn.Module) -> float:
     for name, submodule in module.named_modules():
         if not isinstance(submodule, _ATTENTION_CLASSES):
             continue
-        _apply_faster_cache_on_attention_class(name, submodule, config)
+        if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS):
+            _apply_faster_cache_on_attention_class(name, submodule, config)
 
 
 def _apply_faster_cache_on_denoiser(module: torch.nn.Module, config: FasterCacheConfig) -> None:
@@ -590,13 +589,10 @@ def _apply_faster_cache_on_attention_class(name: str, module: Attention, config:
     is_spatial_self_attention = (
         any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
         and config.spatial_attention_block_skip_range is not None
-        and not module.is_cross_attention
+        and not getattr(module, "is_cross_attention", False)
     )
     is_temporal_self_attention = (
-        any(
-            f"{identifier}." in name or identifier == name
-            for identifier in config.temporal_attention_block_identifiers
-        )
+        any(re.search(identifier, name) is not None for identifier in config.temporal_attention_block_identifiers)
         and config.temporal_attention_block_skip_range is not None
         and not module.is_cross_attention
     )
@@ -633,7 +629,7 @@ def _apply_faster_cache_on_attention_class(name: str, module: Attention, config:
     registry.register_hook(hook, _FASTER_CACHE_BLOCK_HOOK)
 
 
-# Reference: https://github.com/Vchitect/FasterCache/blob/fab32c15014636dc854948319c0a9a8d92c7acb4/scripts/latte/fastercache_sample_latte.py#L127C1-L143C39
+# Reference: https://github.com/Vchitect/FasterCache/blob/fab32c15014636dc854948319c0a9a8d92c7acb4/scripts/latte/faster_cache_sample_latte.py#L127C1-L143C39
 @torch.no_grad()
 def _split_low_high_freq(x):
     fft = torch.fft.fft2(x)
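With the config-is-None fallback removed (hunk at line 515 above), apply_faster_cache now requires an explicit FasterCacheConfig. A minimal usage sketch, importing the function from the module changed in this commit; the model id, dtype, and callback wiring are illustrative assumptions, not part of this commit:

import torch
from diffusers import FasterCacheConfig, FluxPipeline
from diffusers.hooks.faster_cache import apply_faster_cache

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

# The config must now be passed explicitly; there is no default fallback anymore.
config = FasterCacheConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(-1, 901),
    attention_weight_callback=lambda _: 0.5,
    is_guidance_distilled=True,  # Flux is guidance-distilled, so no separate unconditional branch
    current_timestep_callback=lambda: pipe.current_timestep,  # assumes the pipeline tracks its current timestep
)
apply_faster_cache(pipe.transformer, config)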

src/diffusers/hooks/pyramid_attention_broadcast.py

Lines changed: 1 addition & 1 deletion
@@ -177,7 +177,7 @@ def reset_state(self, module: torch.nn.Module) -> None:
 
 def apply_pyramid_attention_broadcast(
     module: torch.nn.Module,
-    config: PyramidAttentionBroadcastConfig,
+    config: PyramidAttentionBroadcastConfig
 ):
     r"""
     Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given pipeline.
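apply_pyramid_attention_broadcast keeps the same call shape. A minimal usage sketch, importing from the module changed in this commit; the pipeline, model id, and config values are illustrative assumptions:

import torch
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
from diffusers.hooks.pyramid_attention_broadcast import apply_pyramid_attention_broadcast

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(100, 800),
    current_timestep_callback=lambda: pipe.current_timestep,
)
apply_pyramid_attention_broadcast(pipe.transformer, config)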

tests/pipelines/flux/test_pipeline_flux.py

Lines changed: 15 additions & 1 deletion
@@ -7,7 +7,13 @@
 from huggingface_hub import hf_hub_download
 from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
 
-from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
+from diffusers import (
+    AutoencoderKL,
+    FasterCacheConfig,
+    FlowMatchEulerDiscreteScheduler,
+    FluxPipeline,
+    FluxTransformer2DModel,
+)
 from diffusers.utils.testing_utils import (
     nightly,
     numpy_cosine_similarity_distance,
@@ -41,6 +47,14 @@ class FluxPipelineFastTests(
     test_xformers_attention = False
     test_layerwise_casting = True
 
+    faster_cache_config = FasterCacheConfig(
+        spatial_attention_block_skip_range=2,
+        spatial_attention_timestep_skip_range=(-1, 901),
+        unconditional_batch_skip_range=2,
+        attention_weight_callback=lambda _: 0.5,
+        is_guidance_distilled=True,
+    )
+
     def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
         torch.manual_seed(0)
         transformer = FluxTransformer2DModel(
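The new class-level faster_cache_config is consumed by FasterCacheTesterMixin in tests/pipelines/test_pipelines_common.py, which is not shown in this commit. A hypothetical sketch of the kind of check such a mixin might run against the dummy pipeline; the method name, tolerance, and wiring are illustrative, not the mixin's actual code:

import numpy as np
from diffusers.hooks.faster_cache import apply_faster_cache
from diffusers.utils.testing_utils import torch_device

# Hypothetical test fragment: assumes it lives on a class providing pipeline_class,
# get_dummy_components, get_dummy_inputs, and the faster_cache_config attribute above.
def test_faster_cache_inference(self):
    pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
    pipe.set_progress_bar_config(disable=None)

    original_output = pipe(**self.get_dummy_inputs(torch_device))[0]

    # Register the FasterCache hooks on the denoiser with the class-level config.
    apply_faster_cache(pipe.transformer, self.faster_cache_config)
    cached_output = pipe(**self.get_dummy_inputs(torch_device))[0]

    # Caching trades some accuracy for speed, so only rough agreement is expected.
    assert np.abs(np.asarray(original_output) - np.asarray(cached_output)).max() < 0.5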

tests/pipelines/hunyuan_video/test_hunyuan_video.py

Lines changed: 18 additions & 2 deletions
@@ -21,6 +21,7 @@
 
 from diffusers import (
     AutoencoderKLHunyuanVideo,
+    FasterCacheConfig,
     FlowMatchEulerDiscreteScheduler,
     HunyuanVideoPipeline,
     HunyuanVideoTransformer3DModel,
@@ -30,13 +31,20 @@
     torch_device,
 )
 
-from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np
+from ..test_pipelines_common import (
+    FasterCacheTesterMixin,
+    PipelineTesterMixin,
+    PyramidAttentionBroadcastTesterMixin,
+    to_np,
+)
 
 
 enable_full_determinism()
 
 
-class HunyuanVideoPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase):
+class HunyuanVideoPipelineFastTests(
+    PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
+):
     pipeline_class = HunyuanVideoPipeline
     params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
     batch_params = frozenset(["prompt"])
@@ -55,6 +63,14 @@ class HunyuanVideoPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadca
     test_xformers_attention = False
     test_layerwise_casting = True
 
+    faster_cache_config = FasterCacheConfig(
+        spatial_attention_block_skip_range=2,
+        spatial_attention_timestep_skip_range=(-1, 901),
+        unconditional_batch_skip_range=2,
+        attention_weight_callback=lambda _: 0.5,
+        is_guidance_distilled=True,
+    )
+
     def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
         torch.manual_seed(0)
         transformer = HunyuanVideoTransformer3DModel(

tests/pipelines/latte/test_latte.py

Lines changed: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ class LattePipelineFastTests(
         cross_attention_block_identifiers=["transformer_blocks"],
     )
 
-    fastercache_config = FasterCacheConfig(
+    faster_cache_config = FasterCacheConfig(
         spatial_attention_block_skip_range=2,
         temporal_attention_block_skip_range=2,
         spatial_attention_timestep_skip_range=(-1, 901),
