Commit bb250d6

update
1 parent 62b5b8d commit bb250d6

2 files changed: 14 additions, 3 deletions


src/diffusers/models/attention_processor.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -930,6 +930,8 @@ def __init__(
         self.out_dim = out_dim if out_dim is not None else query_dim
         self.out_context_dim = out_context_dim if out_context_dim else query_dim
         self.context_pre_only = context_pre_only
+        # TODO(aryan): Maybe try to improve the checks in PAB instead
+        self.is_cross_attention = False
 
         self.heads = out_dim // dim_head if out_dim is not None else heads
```
src/diffusers/pipelines/pyramid_attention_broadcast_utils.py

Lines changed: 12 additions & 3 deletions
```diff
@@ -18,7 +18,7 @@
 
 import torch.nn as nn
 
-from ..models.attention_processor import Attention
+from ..models.attention_processor import Attention, MochiAttention
 from ..models.hooks import ModelHook, add_hook_to_module
 from ..utils import logging
 from .pipeline_utils import DiffusionPipeline
```
```diff
@@ -27,7 +27,7 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-_ATTENTION_CLASSES = (Attention,)
+_ATTENTION_CLASSES = (Attention, MochiAttention)
 
 _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
 _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
```
```diff
@@ -175,8 +175,10 @@ def apply_pyramid_attention_broadcast(
     for name, module in denoiser.named_modules():
         if not isinstance(module, _ATTENTION_CLASSES):
             continue
-        if isinstance(module, Attention):
+        if isinstance(module, (Attention)):
             _apply_pyramid_attention_broadcast_on_attention_class(pipeline, name, module, config)
+        if isinstance(module, MochiAttention):
+            _apply_pyramid_attention_broadcast_on_mochi_attention_class(pipeline, name, module, config)
 
 
 def apply_pyramid_attention_broadcast_on_module(
```
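An aside on the dispatch above: the two `isinstance` checks are effectively mutually exclusive when `MochiAttention` is not an `Attention` subclass (it is registered separately in `_ATTENTION_CLASSES`). An equivalent table-driven form, purely as an illustrative sketch and not part of this commit:

```python
# Illustrative alternative only; note type(module) uses exact-class matching,
# so subclasses of Attention would not be picked up, unlike isinstance.
_PAB_APPLY_FUNCS = {
    Attention: _apply_pyramid_attention_broadcast_on_attention_class,
    MochiAttention: _apply_pyramid_attention_broadcast_on_mochi_attention_class,
}

for name, module in denoiser.named_modules():
    apply_func = _PAB_APPLY_FUNCS.get(type(module))
    if apply_func is not None:
        apply_func(pipeline, name, module, config)
```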
```diff
@@ -263,6 +265,13 @@ def skip_callback(module: nn.Module) -> bool:
         return True
 
 
+def _apply_pyramid_attention_broadcast_on_mochi_attention_class(
+    pipeline: DiffusionPipeline, name: str, module: MochiAttention, config: PyramidAttentionBroadcastConfig
+) -> bool:
+    # The same logic as Attention class works here, so just use that for now
+    return _apply_pyramid_attention_broadcast_on_attention_class(pipeline, name, module, config)
+
+
 class PyramidAttentionBroadcastHook(ModelHook):
     r"""A hook that applies Pyramid Attention Broadcast to a given module."""
 
```
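End to end, the change means a Mochi denoiser's attention blocks are now discovered and hooked like any other. A hedged usage sketch; the exact public signature of `apply_pyramid_attention_broadcast` is not shown in this diff, so the argument order and the config field below are assumptions:

```python
import torch

from diffusers import MochiPipeline
from diffusers.pipelines.pyramid_attention_broadcast_utils import (
    PyramidAttentionBroadcastConfig,
    apply_pyramid_attention_broadcast,
)

pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16)

# With MochiAttention registered in _ATTENTION_CLASSES, the transformer's
# attention blocks are picked up like standard Attention blocks.
# spatial_attention_block_skip_range is an assumed field name for illustration.
config = PyramidAttentionBroadcastConfig(spatial_attention_block_skip_range=2)
apply_pyramid_attention_broadcast(pipe, config)
```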
