|
18 | 18 |
|
19 | 19 | import torch.nn as nn |
20 | 20 |
|
21 | | -from ..models.attention_processor import Attention |
| 21 | +from ..models.attention_processor import Attention, MochiAttention |
22 | 22 | from ..models.hooks import ModelHook, add_hook_to_module |
23 | 23 | from ..utils import logging |
24 | 24 | from .pipeline_utils import DiffusionPipeline |
|
27 | 27 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name |
28 | 28 |
|
29 | 29 |
|
30 | | -_ATTENTION_CLASSES = (Attention,) |
| 30 | +_ATTENTION_CLASSES = (Attention, MochiAttention) |
31 | 31 |
|
32 | 32 | _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks") |
33 | 33 | _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",) |
@@ -175,8 +175,10 @@ def apply_pyramid_attention_broadcast( |
175 | 175 | for name, module in denoiser.named_modules(): |
176 | 176 | if not isinstance(module, _ATTENTION_CLASSES): |
177 | 177 | continue |
178 | | - if isinstance(module, Attention): |
| 178 | + if isinstance(module, (Attention)): |
179 | 179 | _apply_pyramid_attention_broadcast_on_attention_class(pipeline, name, module, config) |
| 180 | + if isinstance(module, MochiAttention): |
| 181 | + _apply_pyramid_attention_broadcast_on_mochi_attention_class(pipeline, name, module, config) |
180 | 182 |
|
181 | 183 |
|
182 | 184 | def apply_pyramid_attention_broadcast_on_module( |
@@ -263,6 +265,13 @@ def skip_callback(module: nn.Module) -> bool: |
263 | 265 | return True |
264 | 266 |
|
265 | 267 |
|
def _apply_pyramid_attention_broadcast_on_mochi_attention_class(
    pipeline: DiffusionPipeline, name: str, module: MochiAttention, config: PyramidAttentionBroadcastConfig
) -> bool:
    """Apply Pyramid Attention Broadcast to a ``MochiAttention`` module.

    ``MochiAttention`` modules take the same hook-application path as regular
    ``Attention`` modules, so this delegates directly to the ``Attention``
    handler and returns its result unchanged.
    """
    return _apply_pyramid_attention_broadcast_on_attention_class(pipeline, name, module, config)
| 273 | + |
| 274 | + |
266 | 275 | class PyramidAttentionBroadcastHook(ModelHook): |
267 | 276 | r"""A hook that applies Pyramid Attention Broadcast to a given module.""" |
268 | 277 |
|
|
0 commit comments