Commit fe93975

address review comments

1 parent 0a290a6

3 files changed (+45, -34 lines)

src/diffusers/hooks/pyramid_attention_broadcast.py

Lines changed: 36 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -137,20 +137,34 @@ class PyramidAttentionBroadcastHook(ModelHook):
137137

138138
_is_stateful = True
139139

140-
def __init__(self, skip_callback: Callable[[torch.nn.Module], bool]) -> None:
140+
def __init__(
141+
self, timestep_skip_range: Tuple[int, int], block_skip_range: int, current_timestep_callback: Callable[[], int]
142+
) -> None:
141143
super().__init__()
142144

143-
self.skip_callback = skip_callback
145+
self.timestep_skip_range = timestep_skip_range
146+
self.block_skip_range = block_skip_range
147+
self.current_timestep_callback = current_timestep_callback
144148

145149
def initialize_hook(self, module):
146150
self.state = PyramidAttentionBroadcastState()
147151
return module
148152

149153
def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
150-
if self.skip_callback(module):
151-
output = self.state.cache
152-
else:
154+
is_within_timestep_range = (
155+
self.timestep_skip_range[0] < self.current_timestep_callback() < self.timestep_skip_range[1]
156+
)
157+
should_compute_attention = (
158+
self.state.cache is None
159+
or self.state.iteration == 0
160+
or not is_within_timestep_range
161+
or self.state.iteration % self.block_skip_range == 0
162+
)
163+
164+
if should_compute_attention:
153165
output = module._old_forward(*args, **kwargs)
166+
else:
167+
output = self.state.cache
154168

155169
self.state.cache = output
156170
self.state.iteration += 1
@@ -266,44 +280,35 @@ def _apply_pyramid_attention_broadcast_on_attention_class(
266280
)
267281
return False
268282

269-
def skip_callback(module: torch.nn.Module) -> bool:
270-
hook: PyramidAttentionBroadcastHook = module._diffusers_hook.get_hook("pyramid_attention_broadcast")
271-
pab_state: PyramidAttentionBroadcastState = hook.state
272-
273-
if pab_state.cache is None:
274-
return False
275-
276-
is_within_timestep_range = timestep_skip_range[0] < config.current_timestep_callback() < timestep_skip_range[1]
277-
if not is_within_timestep_range:
278-
# We are still not in the phase of inference where skipping attention is possible without minimal quality
279-
# loss, as described in the paper. So, the attention computation cannot be skipped
280-
return False
281-
282-
should_compute_attention = pab_state.iteration > 0 and pab_state.iteration % block_skip_range == 0
283-
return not should_compute_attention
284-
285283
logger.debug(f"Enabling Pyramid Attention Broadcast ({block_type}) in layer: {name}")
286-
_apply_pyramid_attention_broadcast(module, skip_callback)
284+
_apply_pyramid_attention_broadcast_hook(
285+
module, timestep_skip_range, block_skip_range, config.current_timestep_callback
286+
)
287287
return True
288288

289289

290-
def _apply_pyramid_attention_broadcast(
290+
def _apply_pyramid_attention_broadcast_hook(
291291
module: Union[Attention, MochiAttention],
292-
skip_callback: Callable[[torch.nn.Module], bool],
292+
timestep_skip_range: Tuple[int, int],
293+
block_skip_range: int,
294+
current_timestep_callback: Callable[[], int],
293295
):
294296
r"""
295297
Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given torch.nn.Module.
296298
297299
Args:
298300
module (`torch.nn.Module`):
299301
The module to apply Pyramid Attention Broadcast to.
300-
skip_callback (`Callable[[nn.Module], bool]`):
301-
A callback function that determines whether the attention computation should be skipped or not. The
302-
callback function should return a boolean value, where `True` indicates that the attention computation
303-
should be skipped, and `False` indicates that the attention computation should not be skipped. The callback
304-
function will receive a torch.nn.Module containing a `_pyramid_attention_broadcast_state` attribute that
305-
can should be used to retrieve and update the state of PAB for the given module.
302+
timestep_skip_range (`Tuple[int, int]`):
303+
The range of timesteps to skip in the attention layer. The attention computations will be conditionally
304+
skipped if the current timestep is within the specified range.
305+
block_skip_range (`int`):
306+
The number of times a specific attention broadcast is skipped before computing the attention states to
307+
re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e., old
308+
attention states will be re-used) before computing the new attention states again.
309+
current_timestep_callback (`Callable[[], int]`):
310+
A callback function that returns the current inference timestep.
306311
"""
307312
registry = HookRegistry.check_if_exists_or_initialize(module)
308-
hook = PyramidAttentionBroadcastHook(skip_callback)
313+
hook = PyramidAttentionBroadcastHook(timestep_skip_range, block_skip_range, current_timestep_callback)
309314
registry.register_hook(hook, "pyramid_attention_broadcast")
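For intuition, here is a minimal standalone sketch of the skip schedule that `new_forward` now implements. The `should_compute_attention` helper, the loop, and the timestep values below are illustrative only, not part of diffusers:

```python
# Standalone sketch (not diffusers code) of the caching logic in
# PyramidAttentionBroadcastHook.new_forward above.

def should_compute_attention(cache, iteration, timestep, timestep_skip_range, block_skip_range):
    is_within_timestep_range = timestep_skip_range[0] < timestep < timestep_skip_range[1]
    return (
        cache is None
        or iteration == 0
        or not is_within_timestep_range
        or iteration % block_skip_range == 0
    )


cache = None
# Illustrative denoising timesteps, a (100, 800) skip range, block_skip_range=2.
for iteration, timestep in enumerate([950, 850, 700, 550, 400, 250, 120]):
    if should_compute_attention(cache, iteration, timestep, (100, 800), 2):
        cache = f"attention@{timestep}"  # stands in for module._old_forward(...)
        print(iteration, timestep, "compute")
    else:
        print(iteration, timestep, f"reuse {cache}")
```

Inside the timestep window, attention is recomputed whenever `iteration % block_skip_range == 0` and the cached output is broadcast otherwise, matching the `N - 1` re-use semantics spelled out in the new docstring.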

src/diffusers/models/cache_utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from ..utils.logging import get_logger
16+
17+
18+
logger = get_logger(__name__) # pylint: disable=invalid-name
19+
1520

1621
class CacheMixin:
1722
r"""
@@ -67,7 +72,8 @@ def disable_cache(self) -> None:
6772
from ..hooks import HookRegistry, PyramidAttentionBroadcastConfig
6873

6974
if self._cache_config is None:
70-
raise ValueError("Caching techniques have not been enabled.")
75+
logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
76+
return
7177

7278
if isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
7379
registry = HookRegistry.check_if_exists_or_initialize(self)
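The behavioral change in `disable_cache` can be exercised with a minimal sketch. `DummyModel` is hypothetical, and this assumes `_cache_config` defaults to `None` on `CacheMixin`:

```python
# Hypothetical minimal example: DummyModel is not a real diffusers class.
from diffusers.models.cache_utils import CacheMixin


class DummyModel(CacheMixin):
    pass


model = DummyModel()
# Before this commit: raised ValueError("Caching techniques have not been enabled.")
# After this commit: logs a warning and returns without error.
model.disable_cache()
```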

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,8 +1143,8 @@ def maybe_free_model_hooks(self):
11431143
Make sure to add this function to the end of the `__call__` function of your pipeline so that it functions
11441144
correctly when applying `enable_model_cpu_offload`.
11451145
"""
1146-
for name, component in self.components.items():
1147-
if name in ("transformer", "unet") and hasattr(component, "_reset_stateful_cache"):
1146+
for component in self.components.values():
1147+
if hasattr(component, "_reset_stateful_cache"):
11481148
component._reset_stateful_cache()
11491149

11501150
if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0:
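This generalizes the cache reset from a hard-coded check on the `"transformer"` and `"unet"` component names to duck typing: any pipeline component exposing `_reset_stateful_cache` is now reset when model hooks are freed. A hypothetical sketch of a component that would opt in (class and method body are illustrative):

```python
# Hypothetical component: any object stored in `pipeline.components` that
# defines `_reset_stateful_cache` is now reset by `maybe_free_model_hooks`.
class MyCachedComponent:
    def __init__(self):
        self._state = {}

    def _reset_stateful_cache(self):
        # Clear per-inference state between pipeline calls.
        self._state.clear()
```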
