Commit 07798ba
Commit message: update
Parent: 9c7e205

3 files changed: 100 additions, 10 deletions
Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.distributed as dist

from ..utils import get_logger
from ._common import _BATCHED_INPUT_IDENTIFIERS
from .hooks import HookRegistry, ModelHook


logger = get_logger(__name__)  # pylint: disable=invalid-name

_CFG_PARALLEL = "cfg_parallel"


class CFGParallelHook(ModelHook):
    def initialize_hook(self, module):
        if not dist.is_initialized():
            raise RuntimeError("Distributed environment not initialized.")
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        if len(args) > 0:
            logger.warning(
                "CFGParallelHook is an example hook that does not work with batched positional arguments. Please use with caution."
            )

        world_size = dist.get_world_size()
        rank = dist.get_rank()

        assert world_size == 2, "This is an example hook designed to only work with 2 processes."

        # Each rank keeps only its chunk of every recognized batched keyword input.
        for key in list(kwargs.keys()):
            if key not in _BATCHED_INPUT_IDENTIFIERS or kwargs[key] is None:
                continue
            kwargs[key] = torch.chunk(kwargs[key], world_size, dim=0)[rank].contiguous()

        # Run the original forward on the local chunk, then gather the full batch back.
        output = self.fn_ref.original_forward(*args, **kwargs)
        sample = output[0]
        sample_list = [torch.empty_like(sample) for _ in range(world_size)]
        dist.all_gather(sample_list, sample)
        sample = torch.cat(sample_list, dim=0).contiguous()

        return_dict = kwargs.get("return_dict", False)
        if not return_dict:
            return (sample, *output[1:])
        return output.__class__(sample, *output[1:])


def apply_cfg_parallel(module: torch.nn.Module) -> None:
    registry = HookRegistry.check_if_exists_or_initialize(module)
    hook = CFGParallelHook()
    registry.register_hook(hook, _CFG_PARALLEL)
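
For context, a minimal usage sketch of the hook above. It assumes that apply_cfg_parallel is importable from diffusers.hooks (the new file's path is not shown in this diff), that the script is launched with torchrun on exactly 2 processes, and that the pipeline batches conditional and unconditional inputs for classifier-free guidance; the checkpoint name is illustrative only.

import torch
import torch.distributed as dist

from diffusers import StableDiffusion3Pipeline
from diffusers.hooks import apply_cfg_parallel  # assumed import path; not part of this diff

# Launch with: torchrun --nproc_per_node=2 cfg_parallel_example.py
dist.init_process_group(backend="nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.bfloat16
).to(f"cuda:{rank}")

# Each rank runs the denoiser on half of the CFG batch (conditional or
# unconditional) and the hook all-gathers the outputs afterwards.
apply_cfg_parallel(pipe.transformer)

image = pipe("a photo of an astronaut riding a horse on mars", num_inference_steps=28).images[0]
if rank == 0:
    image.save("cfg_parallel.png")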

src/diffusers/hooks/_common.py

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
from ..models.attention_processor import Attention, MochiAttention


_ATTENTION_CLASSES = (Attention, MochiAttention)

_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")

_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
    {
        *_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
        *_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
        *_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
    }
)

_BATCHED_INPUT_IDENTIFIERS = (
    "hidden_states",
    "encoder_hidden_states",
    "pooled_projections",
    "timestep",
    "attention_mask",
    "encoder_attention_mask",
    "guidance",
)
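
These identifier tuples are meant to be matched against submodule names when hooks decide where to attach. A minimal illustrative sketch, with a hypothetical helper that is not part of this commit:

import torch

from diffusers.hooks._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS


def iter_transformer_blocks(model: torch.nn.Module):
    # Hypothetical helper: yield submodules whose qualified name contains one of
    # the known transformer-block identifiers (e.g. "transformer_blocks.0").
    for name, submodule in model.named_modules():
        if any(identifier in name for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS):
            yield name, submodule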

src/diffusers/hooks/pyramid_attention_broadcast.py

Lines changed: 9 additions & 10 deletions
@@ -20,19 +20,18 @@
 
 from ..models.attention_processor import Attention, MochiAttention
 from ..utils import logging
+from ._common import (
+    _ATTENTION_CLASSES,
+    _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
+    _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+    _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+)
 from .hooks import HookRegistry, ModelHook
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-_ATTENTION_CLASSES = (Attention, MochiAttention)
-
-_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
-_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
-_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
-
-
 @dataclass
 class PyramidAttentionBroadcastConfig:
     r"""

@@ -76,9 +75,9 @@ class PyramidAttentionBroadcastConfig:
     temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
     cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
 
-    spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
-    temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
-    cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS
+    spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS
+    temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS
+    cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS
 
     current_timestep_callback: Callable[[], int] = None
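
For reference, a short sketch of how the renamed defaults are consumed, assuming the apply_pyramid_attention_broadcast(module, config) helper defined elsewhere in this file (not shown in this hunk) and an illustrative pipeline; the skip range and timestep callback are placeholders:

import torch

from diffusers import CogVideoXPipeline
from diffusers.hooks._common import _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS
from diffusers.hooks.pyramid_attention_broadcast import (
    PyramidAttentionBroadcastConfig,
    apply_pyramid_attention_broadcast,
)

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")

# The identifier fields now default to the shared constants from _common, so the
# newly added "layers" entry is picked up without any explicit configuration.
config = PyramidAttentionBroadcastConfig(
    cross_attention_timestep_skip_range=(100, 800),
    current_timestep_callback=lambda: pipe.current_timestep,  # assumed pipeline attribute
)
assert config.spatial_attention_block_identifiers == _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS

apply_pyramid_attention_broadcast(pipe.transformer, config)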
