Merged

Changes from 30 commits

Commits (34)
d9e7372 init (a-r-r-o-w, Dec 9, 2024)
9a732f0 update (a-r-r-o-w, Dec 9, 2024)
80c5acd update (a-r-r-o-w, Dec 9, 2024)
6047114 update (a-r-r-o-w, Dec 27, 2024)
82d85bd make style (a-r-r-o-w, Dec 27, 2024)
02a2e0d Merge branch 'main' into fastercache (a-r-r-o-w, Dec 27, 2024)
5359222 update (a-r-r-o-w, Dec 29, 2024)
c02f72d fix (a-r-r-o-w, Dec 29, 2024)
c723d5c Merge branch 'main' into fastercache (a-r-r-o-w, Dec 31, 2024)
30d9aaf make it work with guidance distilled models (a-r-r-o-w, Jan 1, 2025)
07edfa9 update (a-r-r-o-w, Jan 2, 2025)
436b772 make fix-copies (a-r-r-o-w, Jan 2, 2025)
d68977d add tests (a-r-r-o-w, Jan 2, 2025)
f3cb80c update (a-r-r-o-w, Jan 2, 2025)
3c498ef apply_faster_cache -> apply_fastercache (a-r-r-o-w, Jan 2, 2025)
4996dfd fix (a-r-r-o-w, Jan 2, 2025)
04874a7 reorder (a-r-r-o-w, Jan 2, 2025)
4a6e62f Merge branch 'main' into fastercache (a-r-r-o-w, Jan 28, 2025)
6de34fe update (a-r-r-o-w, Jan 28, 2025)
d98473d refactor (a-r-r-o-w, Jan 28, 2025)
93de5f3 update docs (a-r-r-o-w, Jan 28, 2025)
ea18eb6 add fastercache to CacheMixin (a-r-r-o-w, Jan 28, 2025)
f92f45e update tests (a-r-r-o-w, Jan 28, 2025)
c60c72e Merge branch 'main' into fastercache (a-r-r-o-w, Jan 28, 2025)
251ade1 Apply suggestions from code review (a-r-r-o-w, Jan 28, 2025)
fa9a1f3 make style (a-r-r-o-w, Jan 28, 2025)
7ad7cc8 try to fix partial import error (a-r-r-o-w, Jan 28, 2025)
4c75017 Merge branch 'main' into fastercache (a-r-r-o-w, Feb 11, 2025)
063e489 Merge branch 'main' into fastercache (a-r-r-o-w, Feb 16, 2025)
dfb62fb Merge branch 'main' into fastercache (a-r-r-o-w, Mar 19, 2025)
a20e846 Apply style fixes (github-actions[bot], Mar 19, 2025)
2a34215 raise warning (a-r-r-o-w, Mar 21, 2025)
6181dea Merge branch 'main' into fastercache (a-r-r-o-w, Mar 21, 2025)
4a4bab8 update (a-r-r-o-w, Mar 21, 2025)
33 changes: 33 additions & 0 deletions docs/source/en/api/cache.md
@@ -38,6 +38,33 @@ config = PyramidAttentionBroadcastConfig(
pipe.transformer.enable_cache(config)
```

## FasterCache

[FasterCache](https://huggingface.co/papers/2410.19355) from Zhengyao Lv, Chenyang Si, Junhao Song, Zhenyu Yang, Yu Qiao, Ziwei Liu, Kwan-Yee K. Wong.

FasterCache is a method that speeds up inference in diffusion transformers by:
- Reusing attention states across successive inference steps, which are highly similar to one another
- Skipping the unconditional branch prediction used in classifier-free guidance: the unconditional and conditional branch outputs at a given timestep are largely redundant, so the unconditional output can be approximated from the conditional one

```python
import torch
from diffusers import CogVideoXPipeline, FasterCacheConfig

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

config = FasterCacheConfig(
spatial_attention_block_skip_range=2,
spatial_attention_timestep_skip_range=(-1, 681),
current_timestep_callback=lambda: pipe.current_timestep,
attention_weight_callback=lambda _: 0.3,
unconditional_batch_skip_range=5,
unconditional_batch_timestep_skip_range=(-1, 781),
tensor_format="BFCHW",
)
pipe.transformer.enable_cache(config)
```
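
To turn the cache off again, this PR also extends `disable_cache` on `CacheMixin` so that it removes the FasterCache hooks:

```python
pipe.transformer.disable_cache()
```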

### CacheMixin

[[autodoc]] CacheMixin
@@ -47,3 +74,9 @@ pipe.transformer.enable_cache(config)
[[autodoc]] PyramidAttentionBroadcastConfig

[[autodoc]] apply_pyramid_attention_broadcast

### FasterCacheConfig

[[autodoc]] FasterCacheConfig

[[autodoc]] apply_faster_cache
10 changes: 9 additions & 1 deletion src/diffusers/__init__.py
@@ -131,8 +131,10 @@
else:
_import_structure["hooks"].extend(
[
"FasterCacheConfig",
"HookRegistry",
"PyramidAttentionBroadcastConfig",
"apply_faster_cache",
"apply_pyramid_attention_broadcast",
]
)
@@ -703,7 +705,13 @@
except OptionalDependencyNotAvailable:
from .utils.dummy_pt_objects import * # noqa F403
else:
from .hooks import HookRegistry, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
from .hooks import (
FasterCacheConfig,
HookRegistry,
PyramidAttentionBroadcastConfig,
apply_faster_cache,
apply_pyramid_attention_broadcast,
)
from .models import (
AllegroTransformer3DModel,
AsymmetricAutoencoderKL,
1 change: 1 addition & 0 deletions src/diffusers/hooks/__init__.py
@@ -2,6 +2,7 @@


if is_torch_available():
from .faster_cache import FasterCacheConfig, apply_faster_cache
from .group_offloading import apply_group_offloading
from .hooks import HookRegistry, ModelHook
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
647 changes: 647 additions & 0 deletions src/diffusers/hooks/faster_cache.py

Large diffs are not rendered by default.
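
The 647-line implementation is collapsed in this view. Its public surface, as wired up elsewhere in the PR, is `FasterCacheConfig` plus `apply_faster_cache(module, config)`, both exported from the top-level `diffusers` namespace. A minimal sketch of that functional path, reusing the configuration values from the docs example above (illustrative values, not tuned recommendations):

```python
import torch
from diffusers import CogVideoXPipeline, FasterCacheConfig, apply_faster_cache

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

config = FasterCacheConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(-1, 681),
    current_timestep_callback=lambda: pipe.current_timestep,
    attention_weight_callback=lambda _: 0.3,
    unconditional_batch_skip_range=5,
    unconditional_batch_timestep_skip_range=(-1, 781),
    tensor_format="BFCHW",
)

# Same effect as `pipe.transformer.enable_cache(config)`, minus the bookkeeping that CacheMixin adds.
apply_faster_cache(pipe.transformer, config)
```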

11 changes: 4 additions & 7 deletions src/diffusers/hooks/pyramid_attention_broadcast.py
@@ -26,8 +26,8 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name


_PYRAMID_ATTENTION_BROADCAST_HOOK = "pyramid_attention_broadcast"
_ATTENTION_CLASSES = (Attention, MochiAttention)

_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
@@ -87,7 +87,7 @@ class PyramidAttentionBroadcastConfig:

def __repr__(self) -> str:
return (
f"PyramidAttentionBroadcastConfig("
f"PyramidAttentionBroadcastConfig(\n"
f" spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n"
f" temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n"
f" cross_attention_block_skip_range={self.cross_attention_block_skip_range},\n"
@@ -175,10 +175,7 @@ def reset_state(self, module: torch.nn.Module) -> None:
return module


def apply_pyramid_attention_broadcast(
module: torch.nn.Module,
config: PyramidAttentionBroadcastConfig,
):
def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAttentionBroadcastConfig):
r"""
Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given pipeline.

@@ -311,4 +308,4 @@ def _apply_pyramid_attention_broadcast_hook(
"""
registry = HookRegistry.check_if_exists_or_initialize(module)
hook = PyramidAttentionBroadcastHook(timestep_skip_range, block_skip_range, current_timestep_callback)
registry.register_hook(hook, "pyramid_attention_broadcast")
registry.register_hook(hook, _PYRAMID_ATTENTION_BROADCAST_HOOK)
25 changes: 22 additions & 3 deletions src/diffusers/models/cache_utils.py
@@ -24,6 +24,7 @@ class CacheMixin:

Supported caching techniques:
- [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588)
- [FasterCache](https://huggingface.co/papers/2410.19355)
"""

_cache_config = None
@@ -59,25 +60,43 @@ def enable_cache(self, config) -> None:
```
"""

from ..hooks import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
from ..hooks import (
FasterCacheConfig,
PyramidAttentionBroadcastConfig,
apply_faster_cache,
apply_pyramid_attention_broadcast,
)

if self.is_cache_enabled:
raise ValueError(
f"Caching has already been enabled with {type(self._cache_config)}. To apply a new caching technique, please disable the existing one first."
)

if isinstance(config, PyramidAttentionBroadcastConfig):
apply_pyramid_attention_broadcast(self, config)
elif isinstance(config, FasterCacheConfig):
apply_faster_cache(self, config)
else:
raise ValueError(f"Cache config {type(config)} is not supported.")

self._cache_config = config

def disable_cache(self) -> None:
from ..hooks import HookRegistry, PyramidAttentionBroadcastConfig
from ..hooks import FasterCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK

if self._cache_config is None:
logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
return

if isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
registry = HookRegistry.check_if_exists_or_initialize(self)
registry.remove_hook("pyramid_attention_broadcast", recurse=True)
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
elif isinstance(self._cache_config, FasterCacheConfig):
registry = HookRegistry.check_if_exists_or_initialize(self)
registry.remove_hook(_FASTER_CACHE_DENOISER_HOOK, recurse=True)
registry.remove_hook(_FASTER_CACHE_BLOCK_HOOK, recurse=True)
else:
raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")

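Taken together, the `CacheMixin` changes above behave roughly as follows (a minimal sketch, assuming `pipe` is a loaded pipeline whose transformer inherits `CacheMixin` and `config` is a `FasterCacheConfig` like the one in the docs example):

```python
pipe.transformer.enable_cache(config)  # registers the FasterCache denoiser/block hooks and stores the config
# A second enable_cache() call without disabling first raises:
#   ValueError: Caching has already been enabled with <class '...FasterCacheConfig'> ...
pipe.transformer.disable_cache()       # removes _FASTER_CACHE_DENOISER_HOOK and _FASTER_CACHE_BLOCK_HOOK
pipe.transformer.disable_cache()       # nothing is enabled now, so this only logs a warning
```
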
2 changes: 1 addition & 1 deletion src/diffusers/models/embeddings.py
@@ -336,7 +336,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
" `from_numpy` is no longer required."
" Pass `output_type='pt' to use the new version now."
)
deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
deprecate("output_type=='np'", "0.34.0", deprecation_message, standard_warn=False)
return get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos)
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be divisible by 2")
1 change: 1 addition & 0 deletions src/diffusers/models/modeling_utils.py
@@ -504,6 +504,7 @@ def enable_layerwise_casting(
non_blocking (`bool`, *optional*, defaults to `False`):
If `True`, the weight casting operations are non-blocking.
"""
from ..hooks import apply_layerwise_casting

user_provided_patterns = True
if skip_modules_pattern is None:
2 changes: 1 addition & 1 deletion src/diffusers/pipelines/latte/pipeline_latte.py
@@ -817,7 +817,7 @@ def __call__(

# predict noise model_output
noise_pred = self.transformer(
latent_model_input,
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=current_timestep,
enable_temporal_attentions=enable_temporal_attentions,
19 changes: 19 additions & 0 deletions src/diffusers/utils/dummy_pt_objects.py
@@ -2,6 +2,21 @@
from ..utils import DummyObject, requires_backends


class FasterCacheConfig(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])


class HookRegistry(metaclass=DummyObject):
_backends = ["torch"]

@@ -32,6 +47,10 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])


def apply_faster_cache(*args, **kwargs):
requires_backends(apply_faster_cache, ["torch"])


def apply_pyramid_attention_broadcast(*args, **kwargs):
requires_backends(apply_pyramid_attention_broadcast, ["torch"])

5 changes: 4 additions & 1 deletion tests/pipelines/cogvideo/test_cogvideox.py
@@ -31,6 +31,7 @@

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
FasterCacheTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
@@ -42,7 +43,9 @@
enable_full_determinism()


class CogVideoXPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase):
class CogVideoXPipelineFastTests(
PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
):
pipeline_class = CogVideoXPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
23 changes: 21 additions & 2 deletions tests/pipelines/flux/test_pipeline_flux.py
@@ -7,7 +7,13 @@
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel

from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
from diffusers import (
AutoencoderKL,
FasterCacheConfig,
FlowMatchEulerDiscreteScheduler,
FluxPipeline,
FluxTransformer2DModel,
)
from diffusers.utils.testing_utils import (
backend_empty_cache,
nightly,
@@ -18,6 +24,7 @@
)

from ..test_pipelines_common import (
FasterCacheTesterMixin,
FluxIPAdapterTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
@@ -27,7 +34,11 @@


class FluxPipelineFastTests(
unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin, PyramidAttentionBroadcastTesterMixin
unittest.TestCase,
PipelineTesterMixin,
FluxIPAdapterTesterMixin,
PyramidAttentionBroadcastTesterMixin,
FasterCacheTesterMixin,
):
pipeline_class = FluxPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
@@ -38,6 +49,14 @@ class FluxPipelineFastTests(
test_layerwise_casting = True
test_group_offloading = True

faster_cache_config = FasterCacheConfig(
spatial_attention_block_skip_range=2,
spatial_attention_timestep_skip_range=(-1, 901),
unconditional_batch_skip_range=2,
attention_weight_callback=lambda _: 0.5,
is_guidance_distilled=True,
)

def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
torch.manual_seed(0)
transformer = FluxTransformer2DModel(
20 changes: 18 additions & 2 deletions tests/pipelines/hunyuan_video/test_hunyuan_video.py
@@ -21,6 +21,7 @@

from diffusers import (
AutoencoderKLHunyuanVideo,
FasterCacheConfig,
FlowMatchEulerDiscreteScheduler,
HunyuanVideoPipeline,
HunyuanVideoTransformer3DModel,
@@ -30,13 +31,20 @@
torch_device,
)

from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np
from ..test_pipelines_common import (
FasterCacheTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
to_np,
)


enable_full_determinism()


class HunyuanVideoPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase):
class HunyuanVideoPipelineFastTests(
PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
):
pipeline_class = HunyuanVideoPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
batch_params = frozenset(["prompt"])
@@ -56,6 +64,14 @@ class HunyuanVideoPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadca
test_layerwise_casting = True
test_group_offloading = True

faster_cache_config = FasterCacheConfig(
spatial_attention_block_skip_range=2,
spatial_attention_timestep_skip_range=(-1, 901),
unconditional_batch_skip_range=2,
attention_weight_callback=lambda _: 0.5,
is_guidance_distilled=True,
)

def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
torch.manual_seed(0)
transformer = HunyuanVideoTransformer3DModel(
21 changes: 19 additions & 2 deletions tests/pipelines/latte/test_latte.py
@@ -25,6 +25,7 @@
from diffusers import (
AutoencoderKL,
DDIMScheduler,
FasterCacheConfig,
LattePipeline,
LatteTransformer3DModel,
PyramidAttentionBroadcastConfig,
@@ -40,13 +41,20 @@
)

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np
from ..test_pipelines_common import (
FasterCacheTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
to_np,
)


enable_full_determinism()


class LattePipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase):
class LattePipelineFastTests(
PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
):
pipeline_class = LattePipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
@@ -69,6 +77,15 @@ class LattePipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTeste
cross_attention_block_identifiers=["transformer_blocks"],
)

faster_cache_config = FasterCacheConfig(
spatial_attention_block_skip_range=2,
temporal_attention_block_skip_range=2,
spatial_attention_timestep_skip_range=(-1, 901),
temporal_attention_timestep_skip_range=(-1, 901),
unconditional_batch_skip_range=2,
attention_weight_callback=lambda _: 0.5,
)

def get_dummy_components(self, num_layers: int = 1):
torch.manual_seed(0)
transformer = LatteTransformer3DModel(