diff --git a/src/diffusers/hooks/_cfg_parallel.py b/src/diffusers/hooks/_cfg_parallel.py
new file mode 100644
index 000000000000..ca4045c39513
--- /dev/null
+++ b/src/diffusers/hooks/_cfg_parallel.py
@@ -0,0 +1,65 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.distributed as dist
+
+from ..utils import get_logger
+from ._common import _BATCHED_INPUT_IDENTIFIERS
+from .hooks import HookRegistry, ModelHook
+
+
+logger = get_logger(__name__)  # pylint: disable=invalid-name
+
+_CFG_PARALLEL = "cfg_parallel"
+
+
+class CFGParallelHook(ModelHook):
+    def initialize_hook(self, module):
+        if not dist.is_initialized():
+            raise RuntimeError("Distributed environment not initialized.")
+        return module
+
+    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+        if len(args) > 0:
+            logger.warning(
+                "CFGParallelHook is an example hook that does not work with batched positional arguments. Please use with caution."
+            )
+
+        world_size = dist.get_world_size()
+        rank = dist.get_rank()
+
+        assert world_size == 2, "This is an example hook designed to only work with 2 processes."
+
+        for key in list(kwargs.keys()):
+            if key not in _BATCHED_INPUT_IDENTIFIERS or kwargs[key] is None:
+                continue
+            kwargs[key] = torch.chunk(kwargs[key], world_size, dim=0)[rank].contiguous()
+
+        output = self.fn_ref.original_forward(*args, **kwargs)
+        sample = output[0]
+        sample_list = [torch.empty_like(sample) for _ in range(world_size)]
+        dist.all_gather(sample_list, sample)
+        sample = torch.cat(sample_list, dim=0).contiguous()
+
+        return_dict = kwargs.get("return_dict", False)
+        if not return_dict:
+            return (sample, *output[1:])
+        return output.__class__(sample, *output[1:])
+
+
+def apply_cfg_parallel(module: torch.nn.Module) -> None:
+    registry = HookRegistry.check_if_exists_or_initialize(module)
+    hook = CFGParallelHook()
+    registry.register_hook(hook, _CFG_PARALLEL)
diff --git a/src/diffusers/hooks/_common.py b/src/diffusers/hooks/_common.py
new file mode 100644
index 000000000000..8c0d71371b1d
--- /dev/null
+++ b/src/diffusers/hooks/_common.py
@@ -0,0 +1,26 @@
+from ..models.attention_processor import Attention, MochiAttention
+
+
+_ATTENTION_CLASSES = (Attention, MochiAttention)
+
+_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
+_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
+_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")
+
+_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
+    {
+        *_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+        *_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+        *_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
+    }
+)
+
+_BATCHED_INPUT_IDENTIFIERS = (
+    "hidden_states",
+    "encoder_hidden_states",
+    "pooled_projections",
+    "timestep",
+    "attention_mask",
+    "encoder_attention_mask",
+    "guidance",
+)
diff --git a/src/diffusers/hooks/pyramid_attention_broadcast.py b/src/diffusers/hooks/pyramid_attention_broadcast.py
index 9f8597d52f8c..e6c06aaa4456 100644
--- a/src/diffusers/hooks/pyramid_attention_broadcast.py
+++ b/src/diffusers/hooks/pyramid_attention_broadcast.py
@@ -20,19 +20,18 @@
 
 from ..models.attention_processor import Attention, MochiAttention
 from ..utils import logging
+from ._common import (
+    _ATTENTION_CLASSES,
+    _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
+    _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+    _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+)
 from .hooks import HookRegistry, ModelHook
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-_ATTENTION_CLASSES = (Attention, MochiAttention)
-
-_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
-_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
-_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
-
-
 @dataclass
 class PyramidAttentionBroadcastConfig:
     r"""
@@ -76,9 +75,9 @@ class PyramidAttentionBroadcastConfig:
     temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
     cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
 
-    spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
-    temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
-    cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS
+    spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS
+    temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS
+    cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS
 
     current_timestep_callback: Callable[[], int] = None
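
For reviewers, a minimal usage sketch of the new `apply_cfg_parallel` hook defined in `_cfg_parallel.py`. This is not part of the diff: the checkpoint id, the choice of `pipe.transformer` as the denoiser, and the `torchrun` launch are illustrative assumptions; only `apply_cfg_parallel`, its 2-process requirement, and the batched-kwarg splitting come from the code above, and the module is private/experimental.

```python
# Hypothetical usage sketch, launched with: torchrun --nproc_per_node=2 run_cfg_parallel.py
# Assumptions (not part of this diff): the model id, that the denoiser is `pipe.transformer`,
# and that the pipeline concatenates conditional/unconditional inputs along dim 0 for CFG.
import torch
import torch.distributed as dist

from diffusers import DiffusionPipeline
from diffusers.hooks._cfg_parallel import apply_cfg_parallel  # private module added in this PR

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

pipe = DiffusionPipeline.from_pretrained("some/placeholder-model", torch_dtype=torch.bfloat16)
pipe.to(f"cuda:{rank}")

# Each rank runs the denoiser on its half of the CFG batch (conditional vs. unconditional);
# the hook all-gathers the denoised samples so every rank sees the full batch again.
apply_cfg_parallel(pipe.transformer)

image = pipe("an astronaut riding a horse", guidance_scale=5.0, num_inference_steps=30).images[0]
if rank == 0:
    image.save("output.png")

dist.destroy_process_group()
```

The point of the sketch is that the pipeline code is untouched: the hook transparently halves the effective batch each rank's denoiser sees and restores the full output via `all_gather`.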