12 | 12 | # See the License for the specific language governing permissions and  | 
13 | 13 | # limitations under the License.  | 
14 | 14 |
15 |  | -from typing import List, Optional, Tuple  | 
 | 15 | +from dataclasses import dataclass, field  | 
 | 16 | +from typing import Callable, List, Optional, Tuple, Type, TypeVar  | 
16 | 17 |
17 | 18 | import torch.nn as nn  | 
18 | 19 |
19 | 20 | from ..models.attention_processor import Attention  | 
20 | 21 | from ..models.hooks import PyramidAttentionBroadcastHook, add_hook_to_module, remove_hook_from_module  | 
21 | 22 | from ..utils import logging  | 
 | 23 | +from .pipeline_utils import DiffusionPipeline  | 
22 | 24 |
23 | 25 |
24 | 26 | logger = logging.get_logger(__name__)  # pylint: disable=invalid-name  | 
25 | 27 |
26 | 28 |
 | 29 | +_ATTENTION_CLASSES = (Attention,)  | 
 | 30 | + | 
 | 31 | +_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ["blocks", "transformer_blocks"]  | 
 | 32 | +_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ["temporal_transformer_blocks"]  | 
 | 33 | +_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ["blocks", "transformer_blocks"]  | 
 | 34 | + | 
 | 35 | + | 
 | 36 | +@dataclass  | 
 | 37 | +class PyramidAttentionBroadcastConfig:  | 
 | 38 | +    spatial_attention_block_skip: Optional[int] = None  | 
 | 39 | +    temporal_attention_block_skip: Optional[int] = None  | 
 | 40 | +    cross_attention_block_skip: Optional[int] = None  | 
 | 41 | + | 
 | 42 | +    spatial_attention_timestep_skip_range: Tuple[int, int] = (100, 800)  | 
 | 43 | +    temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)  | 
 | 44 | +    cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)  | 
 | 45 | + | 
 | 46 | +    spatial_attention_block_identifiers: List[str] = field(default_factory=lambda: _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS)  | 
 | 47 | +    temporal_attention_block_identifiers: List[str] = field(default_factory=lambda: _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS)  | 
 | 48 | +    cross_attention_block_identifiers: List[str] = field(default_factory=lambda: _CROSS_ATTENTION_BLOCK_IDENTIFIERS)  | 
 | 49 | + | 
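For illustration, a config could then be constructed as follows. This is a hypothetical usage sketch that assumes the attributes above are regular dataclass fields; the values shown simply make the spatial defaults explicit.

```python
# Hypothetical usage sketch: enable PAB only for spatial self-attention layers,
# skipping attention recomputation on every 2nd block inside the (100, 800) timestep window.
config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip=2,
    spatial_attention_timestep_skip_range=(100, 800),
)
```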
 | 50 | + | 
 | 51 | +class PyramidAttentionBroadcastState:  | 
 | 52 | +    iteration = 0  | 
 | 53 | + | 
 | 54 | + | 
 | 55 | +def apply_pyramid_attention_broadcast(  | 
 | 56 | +    pipeline: DiffusionPipeline,  | 
 | 57 | +    config: Optional[PyramidAttentionBroadcastConfig] = None,  | 
 | 58 | +    denoiser: Optional[nn.Module] = None,  | 
 | 59 | +):  | 
 | 60 | +    if config is None:  | 
 | 61 | +        config = PyramidAttentionBroadcastConfig()  | 
 | 62 | +      | 
 | 63 | +    if config.spatial_attention_block_skip is None and config.temporal_attention_block_skip is None and config.cross_attention_block_skip is None:  | 
 | 64 | +        logger.warning(  | 
 | 65 | +            "Pyramid Attention Broadcast requires one or more of `spatial_attention_block_skip`, `temporal_attention_block_skip` "  | 
 | 66 | +            "or `cross_attention_block_skip` parameters to be set to an integer, not `None`. Defaulting to using `spatial_attention_block_skip=2`. "  | 
 | 67 | +            "To avoid this warning, please set one of the above parameters."  | 
 | 68 | +        )  | 
 | 69 | +        config.spatial_attention_block_skip = 2  | 
 | 70 | +      | 
 | 71 | +    if denoiser is None:  | 
 | 72 | +        denoiser = pipeline.transformer if hasattr(pipeline, "transformer") else pipeline.unet  | 
 | 73 | +          | 
 | 74 | +    for name, module in denoiser.named_modules():  | 
 | 75 | +        if not isinstance(module, _ATTENTION_CLASSES):  | 
 | 76 | +            continue  | 
 | 77 | +        if isinstance(module, Attention):  | 
 | 78 | +            _apply_pyramid_attention_broadcast_on_attention_class(pipeline, name, module, config)  | 
 | 79 | + | 
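As a hedged end-to-end sketch of how `apply_pyramid_attention_broadcast` is meant to be called (the CogVideoX checkpoint is only an example; any pipeline whose denoiser exposes `Attention` modules would be handled the same way):

```python
import torch
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")

# Passing no config (or one with every `*_block_skip` left as None) falls back to
# spatial_attention_block_skip=2, as per the warning above.
config = PyramidAttentionBroadcastConfig(spatial_attention_block_skip=2)
apply_pyramid_attention_broadcast(pipe, config)
```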
 | 80 | + | 
 | 81 | +# def apply_pyramid_attention_broadcast_spatial(module: TypeVar[_ATTENTION_CLASSES], config: PyramidAttentionBroadcastConfig):  | 
 | 82 | +#     hook = PyramidAttentionBroadcastHook(skip_callback=)  | 
 | 83 | +#     add_hook_to_module(module)  | 
 | 84 | + | 
 | 85 | + | 
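The `apply_pyramid_attention_broadcast_spatial/temporal/cross` helpers called in the dispatcher below are not defined in this diff; only the commented-out draft above hints at their shape. A minimal sketch of what they could look like, assuming `PyramidAttentionBroadcastHook` takes a `skip_callback` that returns `True` when the cached attention output should be reused and that `add_hook_to_module(module, hook)` registers the hook. The `_attach_pab_hook` helper is hypothetical, and the per-timestep skip ranges from the config are ignored for brevity.

```python
def _attach_pab_hook(module: Attention, block_skip: int) -> None:
    # Hypothetical helper: re-compute attention every `block_skip` calls and
    # broadcast (reuse) the cached output in between.
    state = PyramidAttentionBroadcastState()

    def skip_callback(attention_module: nn.Module) -> bool:
        state.iteration += 1
        return state.iteration % block_skip != 0

    hook = PyramidAttentionBroadcastHook(skip_callback=skip_callback)
    add_hook_to_module(module, hook)


def apply_pyramid_attention_broadcast_spatial(module: Attention, config: PyramidAttentionBroadcastConfig):
    _attach_pab_hook(module, config.spatial_attention_block_skip)


def apply_pyramid_attention_broadcast_temporal(module: Attention, config: PyramidAttentionBroadcastConfig):
    _attach_pab_hook(module, config.temporal_attention_block_skip)


def apply_pyramid_attention_broadcast_cross(module: Attention, config: PyramidAttentionBroadcastConfig):
    _attach_pab_hook(module, config.cross_attention_block_skip)
```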
 | 86 | +def _apply_pyramid_attention_broadcast_on_attention_class(pipeline: DiffusionPipeline, name: str, module: Attention, config: PyramidAttentionBroadcastConfig):  | 
 | 87 | +    # Similar check as PEFT to determine if a string layer name matches a module name  | 
 | 88 | +    is_spatial_self_attention = (  | 
 | 89 | +        any(f"{identifier}." in name or identifier == name for identifier in config.spatial_attention_block_identifiers)  | 
 | 90 | +        and config.spatial_attention_timestep_skip_range is not None  | 
 | 91 | +        and not module.is_cross_attention  | 
 | 92 | +    )  | 
 | 93 | +    is_temporal_self_attention = (  | 
 | 94 | +        any(f"{identifier}." in name or identifier == name for identifier in config.temporal_attention_block_identifiers)  | 
 | 95 | +        and config.temporal_attention_timestep_skip_range is not None  | 
 | 96 | +        and not module.is_cross_attention  | 
 | 97 | +    )  | 
 | 98 | +    is_cross_attention = (  | 
 | 99 | +        any(f"{identifier}." in name or identifier == name for identifier in config.cross_attention_block_identifiers)  | 
 | 100 | +        and config.cross_attention_timestep_skip_range is not None  | 
 | 101 | +        and module.is_cross_attention  | 
 | 102 | +    )  | 
 | 103 | + | 
 | 104 | +    if is_spatial_self_attention:  | 
 | 105 | +        apply_pyramid_attention_broadcast_spatial(module, config)  | 
 | 106 | +    elif is_temporal_self_attention:  | 
 | 107 | +        apply_pyramid_attention_broadcast_temporal(module, config)  | 
 | 108 | +    elif is_cross_attention:  | 
 | 109 | +        apply_pyramid_attention_broadcast_cross(module, config)  | 
 | 110 | +    else:  | 
 | 111 | +        logger.warning(f"Unable to apply Pyramid Attention Broadcast to the selected layer: {name}.")  | 
 | 112 | + | 
 | 113 | + | 
27 | 114 | class PyramidAttentionBroadcastMixin:  | 
28 | 115 |     r"""Mixin class for [Pyramid Attention Broadcast](https://www.arxiv.org/abs/2408.12588)."""  | 
29 | 116 |