Commit 1afc0fc

Commit message: update
Parent: 995b82f

File tree: 2 files changed (+91, -38 lines)


src/diffusers/models/hooks.py

Lines changed: 3 additions & 3 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import functools
-from typing import Any, Callable, Dict, Tuple, Union
+from typing import Any, Callable, Dict, Tuple
 
 import torch
 
@@ -166,7 +166,7 @@ def __init__(self, skip_: Callable[[torch.nn.Module], bool]) -> None:
         super().__init__()
 
         self.skip_callback = skip_
-
+
     def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
         args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
 
@@ -179,7 +179,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
             output = None
         else:
             output = module._old_forward(*args, **kwargs)
-
+
         return module._diffusers_hook.post_forward(module, output)
 
 

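The hooks.py changes above are only typing and whitespace cleanups, but they sit inside the hook the rest of this commit relies on: `new_forward` consults a `skip_callback` and either runs the wrapped module's `_old_forward` or hands `post_forward` a `None` output (presumably so a previously computed result can be substituted). The caching side of that exchange is not part of this diff, so the sketch below only illustrates the wrap-and-skip pattern; the `SkipForwardHook` class, its `attach` method, and the output cache are hypothetical stand-ins, not the diffusers hook API.

import torch


class SkipForwardHook:
    # Illustrative stand-in: wrap a module's forward and consult a callback to
    # decide whether to recompute or replay the last computed output.
    def __init__(self, skip_callback):
        self.skip_callback = skip_callback
        self._cached_output = None

    def attach(self, module: torch.nn.Module) -> torch.nn.Module:
        module._old_forward = module.forward

        def new_forward(*args, **kwargs):
            if self.skip_callback(module) and self._cached_output is not None:
                # Skip the expensive computation and reuse the cached result.
                return self._cached_output
            self._cached_output = module._old_forward(*args, **kwargs)
            return self._cached_output

        module.forward = new_forward
        return module


# Skip every second call, mirroring a block skip range of 2.
call_count = {"n": 0}


def skip_every_other(_module) -> bool:
    call_count["n"] += 1
    return call_count["n"] % 2 == 0


layer = SkipForwardHook(skip_every_other).attach(torch.nn.Linear(4, 4))
x = torch.randn(1, 4)
out_first = layer(x)   # computed by the wrapped Linear
out_second = layer(x)  # replayed from the cache
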
src/diffusers/pipelines/pyramid_attention_broadcast_utils.py

Lines changed: 88 additions & 35 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from dataclasses import dataclass
-from typing import Callable, List, Optional, Tuple, Type, TypeVar
+from typing import Callable, List, Optional, Tuple
 
 import torch.nn as nn
 
@@ -28,28 +28,32 @@
 
 _ATTENTION_CLASSES = (Attention,)
 
-_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ["blocks", "transformer_blocks"]
-_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ["temporal_transformer_blocks"]
-_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ["blocks", "transformer_blocks"]
+_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
+_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = "temporal_transformer_blocks"
+_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
 
 
 @dataclass
 class PyramidAttentionBroadcastConfig:
-    spatial_attention_block_skip = None
-    temporal_attention_block_skip = None
-    cross_attention_block_skip = None
-
-    spatial_attention_timestep_skip_range = (100, 800)
-    temporal_attention_timestep_skip_range = (100, 800)
-    cross_attention_timestep_skip_range = (100, 800)
+    spatial_attention_block_skip_range: Optional[int] = None
+    temporal_attention_block_skip_range: Optional[int] = None
+    cross_attention_block_skip_range: Optional[int] = None
 
-    spatial_attention_block_identifiers = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
-    temporal_attention_block_identifiers = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
-    cross_attention_block_identifiers = _CROSS_ATTENTION_BLOCK_IDENTIFIERS
+    spatial_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
+    temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
+    cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
+
+    spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
+    temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
+    cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS
 
 
 class PyramidAttentionBroadcastState:
-    iteration = 0
+    def __init__(self) -> None:
+        self.iteration = 0
+
+    def reset_state(self):
+        self.iteration = 0
 
 
 def apply_pyramid_attention_broadcast(
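With the fields renamed to `*_block_skip_range` and given type annotations above, a caller opts into broadcasting per attention type by setting the corresponding skip range, while identifiers and timestep windows keep their defaults unless overridden. A minimal sketch, assuming the module path from this file header and only the fields visible in this hunk:

from diffusers.pipelines.pyramid_attention_broadcast_utils import PyramidAttentionBroadcastConfig

# Recompute spatial self-attention only every 2nd hooked invocation while the
# current timestep is inside the (100, 800) window; broadcast it otherwise.
config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(100, 800),
)

# Temporal and cross attention are left untouched: their skip ranges stay None,
# which the `is not None` checks added further down treat as "do not hook".
assert config.temporal_attention_block_skip_range is None
assert config.cross_attention_block_skip_range is None
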
@@ -59,56 +63,105 @@ def apply_pyramid_attention_broadcast(
 ):
     if config is None:
         config = PyramidAttentionBroadcastConfig()
-
-    if config.spatial_attention_block_skip is None and config.temporal_attention_block_skip is None and config.cross_attention_block_skip is None:
+
+    if (
+        config.spatial_attention_block_skip_range is None
+        and config.temporal_attention_block_skip_range is None
+        and config.cross_attention_block_skip_range is None
+    ):
         logger.warning(
-            "Pyramid Attention Broadcast requires one or more of `spatial_attention_block_skip`, `temporal_attention_block_skip` "
-            "or `cross_attention_block_skip` parameters to be set to an integer, not `None`. Defaulting to using `spatial_attention_block_skip=2`. "
+            "Pyramid Attention Broadcast requires one or more of `spatial_attention_block_skip_range`, `temporal_attention_block_skip_range` "
+            "or `cross_attention_block_skip_range` parameters to be set to an integer, not `None`. Defaulting to using `spatial_attention_block_skip_range=2`. "
             "To avoid this warning, please set one of the above parameters."
         )
-        config.spatial_attention_block_skip = 2
-
+        config.spatial_attention_block_skip_range = 2
+
     if denoiser is None:
         denoiser = pipeline.transformer if hasattr(pipeline, "transformer") else pipeline.unet
-
+
     for name, module in denoiser.named_modules():
         if not isinstance(module, _ATTENTION_CLASSES):
             continue
         if isinstance(module, Attention):
            _apply_pyramid_attention_broadcast_on_attention_class(pipeline, name, module, config)
 
 
-# def apply_pyramid_attention_broadcast_spatial(module: TypeVar[_ATTENTION_CLASSES], config: PyramidAttentionBroadcastConfig):
-#     hook = PyramidAttentionBroadcastHook(skip_callback=)
-#     add_hook_to_module(module)
+def apply_pyramid_attention_broadcast_on_module(
+    module: Attention,
+    block_skip_range: int,
+    timestep_skip_range: Tuple[int, int],
+    current_timestep_callback: Callable[[], int],
+):
+    module._pyramid_attention_broadcast_state = PyramidAttentionBroadcastState()
+    min_timestep, max_timestep = timestep_skip_range
+
+    def skip_callback(attention_module: nn.Module) -> bool:
+        pab_state: PyramidAttentionBroadcastState = attention_module._pyramid_attention_broadcast_state
+        current_timestep = current_timestep_callback()
+        is_within_timestep_range = min_timestep < current_timestep < max_timestep
+
+        if is_within_timestep_range:
+            # As soon as the current timestep is within the timestep range, we start skipping attention computation.
+            # The following inference steps will compute the attention every `block_skip_range` steps.
+            should_compute_attention = pab_state.iteration > 0 and pab_state.iteration % block_skip_range == 0
+            pab_state.iteration += 1
+            print(current_timestep, is_within_timestep_range, should_compute_attention)
+            return not should_compute_attention
+
+        # We are not yet in the phase of inference where skipping attention is possible with minimal quality
+        # loss, as described in the paper. So, the attention computation cannot be skipped.
+        return False
 
+    hook = PyramidAttentionBroadcastHook(skip_callback=skip_callback)
+    add_hook_to_module(module, hook, append=True)
 
-def _apply_pyramid_attention_broadcast_on_attention_class(pipeline: DiffusionPipeline, name: str, module: Attention, config: PyramidAttentionBroadcastConfig):
+
+def _apply_pyramid_attention_broadcast_on_attention_class(
+    pipeline: DiffusionPipeline, name: str, module: Attention, config: PyramidAttentionBroadcastConfig
+):
     # Similar check as PEFT to determine if a string layer name matches a module name
     is_spatial_self_attention = (
-        any(f"{identifier}." in name or identifier == name for identifier in config.spatial_attention_block_identifiers)
-        and config.spatial_attention_timestep_skip_range is not None
+        any(
+            f"{identifier}." in name or identifier == name for identifier in config.spatial_attention_block_identifiers
+        )
+        and config.spatial_attention_block_skip_range is not None
         and not module.is_cross_attention
     )
     is_temporal_self_attention = (
-        any(f"{identifier}." in name or identifier == name for identifier in config.temporal_attention_block_identifiers)
-        and config.temporal_attention_timestep_skip_range is not None
+        any(
+            f"{identifier}." in name or identifier == name
+            for identifier in config.temporal_attention_block_identifiers
+        )
+        and config.temporal_attention_block_skip_range is not None
         and not module.is_cross_attention
     )
     is_cross_attention = (
         any(f"{identifier}." in name or identifier == name for identifier in config.cross_attention_block_identifiers)
-        and config.cross_attention_timestep_skip_range is not None
+        and config.cross_attention_block_skip_range is not None
         and not module.is_cross_attention
     )
 
+    block_skip_range, timestep_skip_range = None, None
     if is_spatial_self_attention:
-        apply_pyramid_attention_broadcast_spatial(module, config)
+        block_skip_range = config.spatial_attention_block_skip_range
+        timestep_skip_range = config.spatial_attention_timestep_skip_range
     elif is_temporal_self_attention:
-        apply_pyramid_attention_broadcast_temporal(module, config)
+        block_skip_range = config.temporal_attention_block_skip_range
+        timestep_skip_range = config.temporal_attention_timestep_skip_range
     elif is_cross_attention:
-        apply_pyramid_attention_broadcast_cross(module, config)
-    else:
+        block_skip_range = config.cross_attention_block_skip_range
+        timestep_skip_range = config.cross_attention_timestep_skip_range
+
+    if block_skip_range is None or timestep_skip_range is None:
         logger.warning(f"Unable to apply Pyramid Attention Broadcast to the selected layer: {name}.")
+        return
+
+    def current_timestep_callback():
+        return pipeline._current_timestep
+
+    apply_pyramid_attention_broadcast_on_module(
+        module, block_skip_range, timestep_skip_range, current_timestep_callback
+    )
 
 
 class PyramidAttentionBroadcastMixin:
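Taken together, `apply_pyramid_attention_broadcast` walks the denoiser's modules, `_apply_pyramid_attention_broadcast_on_attention_class` classifies each `Attention` layer by its name and `is_cross_attention` flag, and `apply_pyramid_attention_broadcast_on_module` attaches a hook whose `skip_callback` counts invocations and only recomputes attention every `block_skip_range` steps while the current timestep lies inside the configured window. A rough usage sketch follows; the signature `apply_pyramid_attention_broadcast(pipeline, config=None, denoiser=None)`, the placeholder checkpoint name, and the assumption that the pipeline exposes `_current_timestep` (presumably via the `PyramidAttentionBroadcastMixin` declared at the end of this file) are inferred, not confirmed by this diff.

import torch

from diffusers import DiffusionPipeline
from diffusers.pipelines.pyramid_attention_broadcast_utils import (
    PyramidAttentionBroadcastConfig,
    apply_pyramid_attention_broadcast,
)

# Placeholder checkpoint: any pipeline whose denoiser exposes attention blocks
# named like "transformer_blocks.*" (or "blocks.*") and that tracks its
# current timestep during denoising.
pipe = DiffusionPipeline.from_pretrained("some/video-diffusion-checkpoint", torch_dtype=torch.bfloat16)
pipe.to("cuda")

config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(100, 800),
)

# Walks pipe.transformer (or pipe.unet when there is no transformer) and
# attaches a PyramidAttentionBroadcastHook to every matching Attention module.
apply_pyramid_attention_broadcast(pipe, config=config)

output = pipe(prompt="a panda playing a guitar")

With `spatial_attention_block_skip_range=2`, each hooked layer recomputes attention on every second call inside the (100, 800) timestep window and is expected to reuse the previously computed states otherwise; the reuse itself happens in the hook's caching logic, which is outside this diff.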
