Commit 98771d3

committed: update
1 parent 64dec70 commit 98771d3

File tree

5 files changed: +375 −23 lines changed

src/diffusers/hooks/_common.py

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..models.attention_processor import Attention, MochiAttention


_ATTENTION_CLASSES = (Attention, MochiAttention)

_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")

_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
    {
        *_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
        *_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
        *_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
    }
)
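
These identifier tuples name the attributes under which diffusers transformer models usually store their block ModuleLists; the hooks below match on the attribute name rather than on the block class. As a rough illustration (the DummyTransformer is a made-up stand-in, not part of this commit), the matching works like this:

import torch

from diffusers.hooks._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS


class DummyTransformer(torch.nn.Module):
    """Made-up model: only the attribute names matter for identifier matching."""

    def __init__(self):
        super().__init__()
        self.transformer_blocks = torch.nn.ModuleList([torch.nn.Identity() for _ in range(4)])
        self.norm_out = torch.nn.LayerNorm(8)


model = DummyTransformer()
for name, child in model.named_children():
    # The same filter apply_first_block_cache uses further down in this commit.
    if name in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS and isinstance(child, torch.nn.ModuleList):
        print(f"{name}: {len(child)} blocks")  # -> transformer_blocks: 4 blocks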
Lines changed: 262 additions & 0 deletions
@@ -0,0 +1,262 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from dataclasses import dataclass
from typing import Tuple, Union

import torch

from ..utils import get_logger
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
from .hooks import HookRegistry, ModelHook
from .utils import _extract_return_information


logger = get_logger(__name__)  # pylint: disable=invalid-name


_FBC_LEADER_BLOCK_HOOK = "fbc_leader_block_hook"
_FBC_BLOCK_HOOK = "fbc_block_hook"

@dataclass
class FirstBlockCacheConfig:
    r"""
    Configuration for [First Block
    Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching).

    Args:
        threshold (`float`, defaults to `0.05`):
            The threshold to determine whether or not a forward pass through all layers of the model is required. A
            higher threshold usually results in a lower number of forward passes and faster inference, but might lead
            to poorer generation quality. A lower threshold may not result in a significant generation speedup. The
            threshold is compared against the absmean difference of the residuals between the current and cached
            outputs from the first transformer block. If the difference is below the threshold, the forward pass is
            skipped.
    """

    threshold: float = 0.05

class FBCSharedBlockState:
    def __init__(self) -> None:
        self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
        self.head_block_residual: torch.Tensor = None
        self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
        self.should_compute: bool = True

    def reset(self):
        self.tail_block_residuals = None
        self.should_compute = True

    def __repr__(self):
        return f"FBCSharedBlockState(should_compute={self.should_compute})"

class FBCHeadBlockHook(ModelHook):
    _is_stateful = True

    def __init__(self, shared_state: FBCSharedBlockState, threshold: float):
        self.shared_state = shared_state
        self.threshold = threshold

    def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        inputs = inspect.signature(module.__class__.forward)
        inputs_index_to_str = dict(enumerate(inputs.parameters.keys()))
        inputs_str_to_index = {v: k for k, v in inputs_index_to_str.items()}

        try:
            outputs = _extract_return_information(module.__class__.forward)
            outputs_index_to_str = dict(enumerate(outputs))
            outputs_str_to_index = {v: k for k, v in outputs_index_to_str.items()}
        except RuntimeError:
            logger.error(f"Failed to extract return information for {module.__class__}")
            raise NotImplementedError(
                f"Module {module.__class__} is not supported with FirstBlockCache. Please open an issue at "
                f"https://github.com/huggingface/diffusers to notify us about the error with a minimal example "
                f"in order for us to add support for this module."
            )

        self._inputs_index_to_str = inputs_index_to_str
        self._inputs_str_to_index = inputs_str_to_index
        self._outputs_index_to_str = outputs_index_to_str
        self._outputs_str_to_index = outputs_str_to_index
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        hs_input_idx = self._inputs_str_to_index.get("hidden_states")
        ehs_input_idx = self._inputs_str_to_index.get("encoder_hidden_states", None)
        original_hs = kwargs.get("hidden_states", None)
        original_ehs = kwargs.get("encoder_hidden_states", None)
        original_hs = original_hs if original_hs is not None else args[hs_input_idx]
        if ehs_input_idx is not None:
            original_ehs = original_ehs if original_ehs is not None else args[ehs_input_idx]

        hs_output_idx = self._outputs_str_to_index.get("hidden_states")
        ehs_output_idx = self._outputs_str_to_index.get("encoder_hidden_states", None)
        assert (ehs_input_idx is None) == (ehs_output_idx is None)

        output = self.fn_ref.original_forward(*args, **kwargs)

        hs_residual = None
        if isinstance(output, tuple):
            hs_residual = output[hs_output_idx] - original_hs
        else:
            hs_residual = output - original_hs

        should_compute = self._should_compute_remaining_blocks(hs_residual)
        self.shared_state.should_compute = should_compute

        hs, ehs = None, None
        if not should_compute:
            # Apply caching
            logger.info("Skipping forward pass through remaining blocks")
            hs = self.shared_state.tail_block_residuals[0] + output[hs_output_idx]
            if ehs_output_idx is not None:
                ehs = self.shared_state.tail_block_residuals[1] + output[ehs_output_idx]

            if isinstance(output, tuple):
                return_output = [None] * len(output)
                return_output[hs_output_idx] = hs
                return_output[ehs_output_idx] = ehs
                return_output = tuple(return_output)
            else:
                return_output = hs
            return return_output
        else:
            logger.info("Computing forward pass through remaining blocks")
            if isinstance(output, tuple):
                head_block_output = [None] * len(output)
                head_block_output[0] = output[hs_output_idx]
                head_block_output[1] = output[ehs_output_idx]
            else:
                head_block_output = output
            self.shared_state.head_block_output = head_block_output
            self.shared_state.head_block_residual = hs_residual
            return output

    def reset_state(self, module):
        self.shared_state.reset()
        return module

    def _should_compute_remaining_blocks(self, hs_residual: torch.Tensor) -> bool:
        if self.shared_state.head_block_residual is None:
            return True
        prev_hs_residual = self.shared_state.head_block_residual
        hs_absmean = (hs_residual - prev_hs_residual).abs().mean()
        prev_hs_mean = prev_hs_residual.abs().mean()
        diff = (hs_absmean / prev_hs_mean).item()
        logger.info(f"Diff: {diff}, Threshold: {self.threshold}")
        return diff > self.threshold

class FBCBlockHook(ModelHook):
    def __init__(self, shared_state: FBCSharedBlockState, is_tail: bool = False):
        super().__init__()
        self.shared_state = shared_state
        self.is_tail = is_tail

    def initialize_hook(self, module):
        inputs = inspect.signature(module.__class__.forward)
        inputs_index_to_str = dict(enumerate(inputs.parameters.keys()))
        inputs_str_to_index = {v: k for k, v in inputs_index_to_str.items()}

        try:
            outputs = _extract_return_information(module.__class__.forward)
            outputs_index_to_str = dict(enumerate(outputs))
            outputs_str_to_index = {v: k for k, v in outputs_index_to_str.items()}
        except RuntimeError:
            logger.error(f"Failed to extract return information for {module.__class__}")
            raise NotImplementedError(
                f"Module {module.__class__} is not supported with FirstBlockCache. Please open an issue at "
                f"https://github.com/huggingface/diffusers to notify us about the error with a minimal example "
                f"in order for us to add support for this module."
            )

        self._inputs_index_to_str = inputs_index_to_str
        self._inputs_str_to_index = inputs_str_to_index
        self._outputs_index_to_str = outputs_index_to_str
        self._outputs_str_to_index = outputs_str_to_index
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        hs_input_idx = self._inputs_str_to_index.get("hidden_states")
        ehs_input_idx = self._inputs_str_to_index.get("encoder_hidden_states", None)
        original_hs = kwargs.get("hidden_states", None)
        original_ehs = kwargs.get("encoder_hidden_states", None)
        original_hs = original_hs if original_hs is not None else args[hs_input_idx]
        if ehs_input_idx is not None:
            original_ehs = original_ehs if original_ehs is not None else args[ehs_input_idx]

        hs_output_idx = self._outputs_str_to_index.get("hidden_states")
        ehs_output_idx = self._outputs_str_to_index.get("encoder_hidden_states", None)
        assert (ehs_input_idx is None) == (ehs_output_idx is None)

        if self.shared_state.should_compute:
            output = self.fn_ref.original_forward(*args, **kwargs)
            if self.is_tail:
                hs_residual, ehs_residual = None, None
                if isinstance(output, tuple):
                    hs_residual = output[hs_output_idx] - self.shared_state.head_block_output[0]
                    ehs_residual = output[ehs_output_idx] - self.shared_state.head_block_output[1]
                else:
                    hs_residual = output - self.shared_state.head_block_output
                self.shared_state.tail_block_residuals = (hs_residual, ehs_residual)
            return output

        output_count = len(self._outputs_index_to_str.keys())
        return_output = [None] * output_count if output_count > 1 else original_hs
        if output_count == 1:
            return_output = original_hs
        else:
            return_output[hs_output_idx] = original_hs
            return_output[ehs_output_idx] = original_ehs
        return return_output

def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
    shared_state = FBCSharedBlockState()
    remaining_blocks = []

    for name, submodule in module.named_children():
        if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
            continue
        for block in submodule:
            remaining_blocks.append((name, block))

    head_block_name, head_block = remaining_blocks.pop(0)
    tail_block_name, tail_block = remaining_blocks.pop(-1)

    logger.debug(f"Apply FBCHeadBlockHook to '{head_block_name}'")
    apply_fbc_head_block_hook(head_block, shared_state, config.threshold)

    for name, block in remaining_blocks:
        logger.debug(f"Apply FBCBlockHook to '{name}'")
        apply_fbc_block_hook(block, shared_state)

    logger.debug(f"Apply FBCBlockHook to tail block '{tail_block_name}'")
    apply_fbc_block_hook(tail_block, shared_state, is_tail=True)


def apply_fbc_head_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, threshold: float) -> None:
    registry = HookRegistry.check_if_exists_or_initialize(block)
    hook = FBCHeadBlockHook(state, threshold)
    registry.register_hook(hook, _FBC_LEADER_BLOCK_HOOK)


def apply_fbc_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, is_tail: bool = False) -> None:
    registry = HookRegistry.check_if_exists_or_initialize(block)
    hook = FBCBlockHook(state, is_tail)
    registry.register_hook(hook, _FBC_BLOCK_HOOK)
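
The head block hook decides per step whether the remaining blocks run: it compares mean(|r_t - r_{t-1}|) / mean(|r_{t-1}|) of the first block's hidden-state residual against the configured threshold and, when the relative change is below it, replays the cached tail residuals instead (see _should_compute_remaining_blocks above). A rough usage sketch follows; the import path for the helpers is an assumption, since this diff does not show how the new module is exported, and the pipeline choice is only illustrative:

import torch

from diffusers import FluxPipeline
# Assumed import location for the helpers defined in the new file above; adjust to the actual export.
from diffusers.hooks import FirstBlockCacheConfig, apply_first_block_cache

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")

# Registers FBCHeadBlockHook on the first transformer block and FBCBlockHook on the rest,
# all sharing one FBCSharedBlockState.
apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.05))

image = pipe("A photo of a cat holding a sign that says hello", num_inference_steps=28).images[0]
image.save("fbc_flux.png")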

src/diffusers/hooks/pyramid_attention_broadcast.py

Lines changed: 9 additions & 10 deletions
@@ -20,19 +20,18 @@
 
 from ..models.attention_processor import Attention, MochiAttention
 from ..utils import logging
+from ._common import (
+    _ATTENTION_CLASSES,
+    _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
+    _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+    _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+)
 from .hooks import HookRegistry, ModelHook
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-_ATTENTION_CLASSES = (Attention, MochiAttention)
-
-_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
-_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
-_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
-
-
 @dataclass
 class PyramidAttentionBroadcastConfig:
     r"""
@@ -76,9 +75,9 @@ class PyramidAttentionBroadcastConfig:
     temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
     cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
 
-    spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
-    temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
-    cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS
+    spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS
+    temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS
+    cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS
 
     current_timestep_callback: Callable[[], int] = None
 
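
With the constants now shared from _common.py, the Pyramid Attention Broadcast config defaults pick up the same identifier tuples used by first block cache (including the new "layers" entry on the spatial and cross lists). For reference, a typical invocation looks roughly like the sketch below; the current-timestep callback is an assumption about the pipeline's attributes and may need adapting:

import torch

from diffusers import CogVideoXPipeline
from diffusers.hooks import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")

config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(100, 800),
    # Assumption: the pipeline exposes the timestep currently being denoised.
    current_timestep_callback=lambda: pipe.current_timestep,
)
apply_pyramid_attention_broadcast(pipe.transformer, config)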

src/diffusers/hooks/utils.py

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
import ast
import inspect
import textwrap
from typing import List


def _extract_return_information(func) -> List[str]:
    """Extracts return variable names in order from a function."""
    try:
        source = inspect.getsource(func)
        source = textwrap.dedent(source)  # Modify indentation to make parsing compatible
    except (OSError, TypeError):
        try:
            source_file = inspect.getfile(func)
            with open(source_file, "r", encoding="utf-8") as f:
                source = f.read()

            # Extract function definition manually
            source_lines = source.splitlines()
            func_name = func.__name__
            start_line = None
            indent_level = None
            extracted_lines = []

            for i, line in enumerate(source_lines):
                stripped = line.strip()
                if stripped.startswith(f"def {func_name}("):
                    start_line = i
                    indent_level = len(line) - len(line.lstrip())
                    extracted_lines.append(line)
                    continue

                if start_line is not None:
                    # Stop when indentation level decreases (end of function)
                    current_indent = len(line) - len(line.lstrip())
                    if current_indent <= indent_level and line.strip():
                        break
                    extracted_lines.append(line)

            source = "\n".join(extracted_lines)
        except Exception as e:
            raise RuntimeError(f"Failed to retrieve function source: {e}")

    # Parse source code using AST
    tree = ast.parse(source)
    return_vars = []

    class ReturnVisitor(ast.NodeVisitor):
        def visit_Return(self, node):
            if isinstance(node.value, ast.Tuple):
                # Multiple return values
                return_vars.extend(var.id for var in node.value.elts if isinstance(var, ast.Name))
            elif isinstance(node.value, ast.Name):
                # Single return value
                return_vars.append(node.value.id)

    visitor = ReturnVisitor()
    visitor.visit(tree)
    return return_vars
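
As a quick illustration, the AST walk returns the variable names in the order they appear in the return statement, which is what the FBC hooks use to map output names to tuple indices. The toy function below is illustrative only:

from diffusers.hooks.utils import _extract_return_information


def fake_block_forward(hidden_states, encoder_hidden_states, temb):
    hidden_states = hidden_states + 1
    encoder_hidden_states = encoder_hidden_states * 2
    return hidden_states, encoder_hidden_states


print(_extract_return_information(fake_block_forward))
# ['hidden_states', 'encoder_hidden_states']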

0 commit comments