Commit 4476402

Merge branch 'main' into to-single-file/wan

2 parents 1dc9c65 + 6f3ac30, commit 4476402

20 files changed (+1148, -136 lines)

scripts/convert_wan_to_diffusers.py

Lines changed: 419 additions & 5 deletions
Large diffs are not rendered by default.

src/diffusers/hooks/_common.py

Lines changed: 15 additions & 2 deletions

@@ -16,11 +16,11 @@
 
 import torch
 
-from ..models.attention import FeedForward, LuminaFeedForward
+from ..models.attention import AttentionModuleMixin, FeedForward, LuminaFeedForward
 from ..models.attention_processor import Attention, MochiAttention
 
 
-_ATTENTION_CLASSES = (Attention, MochiAttention)
+_ATTENTION_CLASSES = (Attention, MochiAttention, AttentionModuleMixin)
 _FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)
 
 _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
@@ -35,6 +35,19 @@
     }
 )
 
+# Layers supported for group offloading and layerwise casting
+_GO_LC_SUPPORTED_PYTORCH_LAYERS = (
+    torch.nn.Conv1d,
+    torch.nn.Conv2d,
+    torch.nn.Conv3d,
+    torch.nn.ConvTranspose1d,
+    torch.nn.ConvTranspose2d,
+    torch.nn.ConvTranspose3d,
+    torch.nn.Linear,
+    # TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX
+    # because of double invocation of the same norm layer in CogVideoXLayerNorm
+)
+
 
 def _get_submodule_from_fqn(module: torch.nn.Module, fqn: str) -> Optional[torch.nn.Module]:
     for submodule_name, submodule in module.named_modules():
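
Centralizing the layer tuple in `_common.py` gives group offloading and layerwise casting a single source of truth, and adding `AttentionModuleMixin` to `_ATTENTION_CLASSES` lets the cache hooks match any attention module built on the mixin. A minimal sketch of how these shared constants are consumed downstream; the helper below is hypothetical, the real hooks perform the same `isinstance` checks inline:

```python
import torch

# Private module; imported here only to illustrate how the shared constants are used.
from diffusers.hooks._common import _ATTENTION_CLASSES, _GO_LC_SUPPORTED_PYTORCH_LAYERS


def classify_submodules(model: torch.nn.Module):
    """Split submodules into attention-like modules and offloadable/castable leaf layers."""
    attention_like, leaf_layers = [], []
    for name, submodule in model.named_modules():
        if isinstance(submodule, _ATTENTION_CLASSES):
            attention_like.append(name)
        elif isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
            leaf_layers.append(name)
    return attention_like, leaf_layers
```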

src/diffusers/hooks/faster_cache.py

Lines changed: 7 additions & 7 deletions

@@ -19,9 +19,9 @@
 import torch
 
 from ..models.attention import AttentionModuleMixin
-from ..models.attention_processor import Attention, MochiAttention
 from ..models.modeling_outputs import Transformer2DModelOutput
 from ..utils import logging
+from ._common import _ATTENTION_CLASSES
 from .hooks import HookRegistry, ModelHook
 
 
@@ -30,7 +30,6 @@
 
 _FASTER_CACHE_DENOISER_HOOK = "faster_cache_denoiser"
 _FASTER_CACHE_BLOCK_HOOK = "faster_cache_block"
-_ATTENTION_CLASSES = (Attention, MochiAttention)
 _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
     "^blocks.*attn",
     "^transformer_blocks.*attn",
@@ -489,9 +488,10 @@ def apply_faster_cache(module: torch.nn.Module, config: FasterCacheConfig) -> None:
     Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline.
 
     Args:
-        pipeline (`DiffusionPipeline`):
-            The diffusion pipeline to apply FasterCache to.
-        config (`Optional[FasterCacheConfig]`, `optional`, defaults to `None`):
+        module (`torch.nn.Module`):
+            The pytorch module to apply FasterCache to. Typically, this should be a transformer architecture supported
+            in Diffusers, such as `CogVideoXTransformer3DModel`, but external implementations may also work.
+        config (`FasterCacheConfig`):
             The configuration to use for FasterCache.
 
     Example:
@@ -568,7 +568,7 @@ def high_frequency_weight_callback(module: torch.nn.Module) -> float:
     _apply_faster_cache_on_denoiser(module, config)
 
     for name, submodule in module.named_modules():
-        if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
+        if not isinstance(submodule, _ATTENTION_CLASSES):
             continue
         if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS):
             _apply_faster_cache_on_attention_class(name, submodule, config)
@@ -589,7 +589,7 @@ def _apply_faster_cache_on_denoiser(module: torch.nn.Module, config: FasterCacheConfig) -> None:
     registry.register_hook(hook, _FASTER_CACHE_DENOISER_HOOK)
 
 
-def _apply_faster_cache_on_attention_class(name: str, module: Attention, config: FasterCacheConfig) -> None:
+def _apply_faster_cache_on_attention_class(name: str, module: AttentionModuleMixin, config: FasterCacheConfig) -> None:
     is_spatial_self_attention = (
         any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
         and config.spatial_attention_block_skip_range is not None
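
With this change, FasterCache is documented as being applied to the denoiser module itself rather than to a whole pipeline, and attention layers are matched through the shared `_ATTENTION_CLASSES` tuple. A minimal usage sketch, assuming the publicly exported `apply_faster_cache` / `FasterCacheConfig` API; the model choice and config values below are illustrative, not prescribed by this commit:

```python
import torch

from diffusers import CogVideoXPipeline
from diffusers.hooks import FasterCacheConfig, apply_faster_cache

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Skip spatial attention on alternating steps inside the given timestep window,
# re-weighting cached outputs with a constant attention weight.
config = FasterCacheConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(-1, 681),
    current_timestep_callback=lambda: pipe.current_timestep,
    attention_weight_callback=lambda _: 0.3,
)
apply_faster_cache(pipe.transformer, config)
```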

src/diffusers/hooks/first_block_cache.py

Lines changed: 32 additions & 0 deletions

@@ -192,6 +192,38 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
 
 
 def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
+    """
+    Applies [First Block
+    Cache](https://github.com/chengzeyi/ParaAttention/blob/4de137c5b96416489f06e43e19f2c14a772e28fd/README.md#first-block-cache-our-dynamic-caching)
+    to a given module.
+
+    First Block Cache builds on the ideas of [TeaCache](https://huggingface.co/papers/2411.19108). It is much simpler
+    to implement generically for a wide range of models and has been integrated first for experimental purposes.
+
+    Args:
+        module (`torch.nn.Module`):
+            The pytorch module to apply FBCache to. Typically, this should be a transformer architecture supported in
+            Diffusers, such as `CogVideoXTransformer3DModel`, but external implementations may also work.
+        config (`FirstBlockCacheConfig`):
+            The configuration to use for applying the FBCache method.
+
+    Example:
+    ```python
+    >>> import torch
+    >>> from diffusers import CogView4Pipeline
+    >>> from diffusers.hooks import apply_first_block_cache, FirstBlockCacheConfig
+
+    >>> pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
+    >>> pipe.to("cuda")
+
+    >>> apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.2))
+
+    >>> prompt = "A photo of an astronaut riding a horse on mars"
+    >>> image = pipe(prompt, generator=torch.Generator().manual_seed(42)).images[0]
+    >>> image.save("output.png")
+    ```
+    """
+
     state_manager = StateManager(FBCSharedBlockState, (), {})
     remaining_blocks = []

src/diffusers/hooks/group_offloading.py

Lines changed: 13 additions & 16 deletions

@@ -23,6 +23,7 @@
 import torch
 
 from ..utils import get_logger, is_accelerate_available
+from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
 from .hooks import HookRegistry, ModelHook
 
 
@@ -39,13 +40,6 @@
 _LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
 _LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
 _GROUP_ID_LAZY_LEAF = "lazy_leafs"
-_SUPPORTED_PYTORCH_LAYERS = (
-    torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
-    torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
-    torch.nn.Linear,
-    # TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX
-    # because of double invocation of the same norm layer in CogVideoXLayerNorm
-)
 # fmt: on
 
 
@@ -367,7 +361,8 @@ def __init__(self):
     def initialize_hook(self, module):
         def make_execution_order_update_callback(current_name, current_submodule):
             def callback():
-                logger.debug(f"Adding {current_name} to the execution order")
+                if not torch.compiler.is_compiling():
+                    logger.debug(f"Adding {current_name} to the execution order")
                 self.execution_order.append((current_name, current_submodule))
 
             return callback
@@ -404,12 +399,13 @@ def post_forward(self, module, output):
         # if the missing layers end up being executed in the future.
         if execution_order_module_names != self._layer_execution_tracker_module_names:
             unexecuted_layers = list(self._layer_execution_tracker_module_names - execution_order_module_names)
-            logger.warning(
-                "It seems like some layers were not executed during the forward pass. This may lead to problems when "
-                "applying lazy prefetching with automatic tracing and lead to device-mismatch related errors. Please "
-                "make sure that all layers are executed during the forward pass. The following layers were not executed:\n"
-                f"{unexecuted_layers=}"
-            )
+            if not torch.compiler.is_compiling():
+                logger.warning(
+                    "It seems like some layers were not executed during the forward pass. This may lead to problems when "
+                    "applying lazy prefetching with automatic tracing and lead to device-mismatch related errors. Please "
+                    "make sure that all layers are executed during the forward pass. The following layers were not executed:\n"
+                    f"{unexecuted_layers=}"
+                )
 
         # Remove the layer execution tracker hooks from the submodules
         base_module_registry = module._diffusers_hook
@@ -437,7 +433,8 @@ def post_forward(self, module, output):
         for i in range(num_executed - 1):
             name1, _ = self.execution_order[i]
             name2, _ = self.execution_order[i + 1]
-            logger.debug(f"Applying lazy prefetch group offloading from {name1} to {name2}")
+            if not torch.compiler.is_compiling():
+                logger.debug(f"Applying lazy prefetch group offloading from {name1} to {name2}")
             group_offloading_hooks[i].next_group = group_offloading_hooks[i + 1].group
             group_offloading_hooks[i].next_group.onload_self = False
 
@@ -680,7 +677,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
     # Create module groups for leaf modules and apply group offloading hooks
     modules_with_group_offloading = set()
     for name, submodule in module.named_modules():
-        if not isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
+        if not isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
             continue
         group = ModuleGroup(
             modules=[submodule],
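
The logging calls above are now guarded with `torch.compiler.is_compiling()` so that log formatting does not run while the hooks are being traced by `torch.compile`. A minimal sketch of the same guard pattern in isolation (the helper function and message below are illustrative, not part of this commit):

```python
import torch

from diffusers.utils import get_logger

logger = get_logger(__name__)


def debug_when_eager(message: str) -> None:
    # Emit the log only in eager execution; while torch.compile is tracing,
    # skip it so the logger call does not interfere with graph capture.
    if not torch.compiler.is_compiling():
        logger.debug(message)


debug_when_eager("Adding block_0 to the execution order")
```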

src/diffusers/hooks/layerwise_casting.py

Lines changed: 2 additions & 7 deletions

@@ -18,6 +18,7 @@
 import torch
 
 from ..utils import get_logger, is_peft_available, is_peft_version
+from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
 from .hooks import HookRegistry, ModelHook
 
 
@@ -27,12 +28,6 @@
 # fmt: off
 _LAYERWISE_CASTING_HOOK = "layerwise_casting"
 _PEFT_AUTOCAST_DISABLE_HOOK = "peft_autocast_disable"
-SUPPORTED_PYTORCH_LAYERS = (
-    torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
-    torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
-    torch.nn.Linear,
-)
-
 DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm", "^proj_in$", "^proj_out$")
 # fmt: on
 
@@ -186,7 +181,7 @@ def _apply_layerwise_casting(
         logger.debug(f'Skipping layerwise casting for layer "{_prefix}"')
         return
 
-    if isinstance(module, SUPPORTED_PYTORCH_LAYERS):
+    if isinstance(module, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
         logger.debug(f'Applying layerwise casting to layer "{_prefix}"')
         apply_layerwise_casting_hook(module, storage_dtype, compute_dtype, non_blocking)
         return
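
Layerwise casting now shares the same supported-layer tuple as group offloading instead of defining its own `SUPPORTED_PYTORCH_LAYERS` constant. A minimal usage sketch, assuming the publicly exported `apply_layerwise_casting` helper; the model and dtype choices below are illustrative:

```python
import torch

from diffusers import CogVideoXTransformer3DModel
from diffusers.hooks import apply_layerwise_casting

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Store weights of supported leaf layers (convs and linears) in FP8 and upcast
# them to BF16 for compute; modules matching the skip pattern keep their dtype.
apply_layerwise_casting(
    transformer,
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
    skip_modules_pattern=["patch_embed", "norm"],
)
```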

src/diffusers/hooks/pyramid_attention_broadcast.py

Lines changed: 12 additions & 10 deletions

@@ -21,17 +21,19 @@
 from ..models.attention import AttentionModuleMixin
 from ..models.attention_processor import Attention, MochiAttention
 from ..utils import logging
+from ._common import (
+    _ATTENTION_CLASSES,
+    _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
+    _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+    _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+)
 from .hooks import HookRegistry, ModelHook
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
 _PYRAMID_ATTENTION_BROADCAST_HOOK = "pyramid_attention_broadcast"
-_ATTENTION_CLASSES = (Attention, MochiAttention)
-_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
-_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
-_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
 
 
 @dataclass
@@ -61,11 +63,11 @@ class PyramidAttentionBroadcastConfig:
         cross_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
             The range of timesteps to skip in the cross-attention layer. The attention computations will be
             conditionally skipped if the current timestep is within the specified range.
-        spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
+        spatial_attention_block_identifiers (`Tuple[str, ...]`):
             The identifiers to match against the layer names to determine if the layer is a spatial attention layer.
-        temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks",)`):
+        temporal_attention_block_identifiers (`Tuple[str, ...]`):
             The identifiers to match against the layer names to determine if the layer is a temporal attention layer.
-        cross_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
+        cross_attention_block_identifiers (`Tuple[str, ...]`):
             The identifiers to match against the layer names to determine if the layer is a cross-attention layer.
     """
 
@@ -77,9 +79,9 @@ class PyramidAttentionBroadcastConfig:
     temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
     cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
 
-    spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
-    temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
-    cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS
+    spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS
+    temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS
+    cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS
 
     current_timestep_callback: Callable[[], int] = None
 
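
The default block identifiers now come from the shared `_common` constants, so `PyramidAttentionBroadcastConfig` matches the same block names as the other hooks. A minimal usage sketch, assuming the publicly exported `apply_pyramid_attention_broadcast`; the model and skip settings below are illustrative:

```python
import torch

from diffusers import CogVideoXPipeline
from diffusers.hooks import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Reuse (broadcast) spatial attention outputs for up to 2 consecutive steps while
# the current timestep falls inside the (100, 800) window.
config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(100, 800),
    current_timestep_callback=lambda: pipe.current_timestep,
)
apply_pyramid_attention_broadcast(pipe.transformer, config)
```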
