
Commit 9182f57

refactor
1 parent 0b2629d commit 9182f57

File tree: 3 files changed, +36 −83 lines

src/diffusers/models/hooks.py

Lines changed: 1 addition & 60 deletions
```diff
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import functools
-from typing import Any, Callable, Dict, Tuple
+from typing import Any, Dict, Tuple
 
 import torch
 
@@ -78,9 +78,6 @@ def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         """
         return module
 
-    def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
-        return module
-
 
 class SequentialHook(ModelHook):
     r"""A hook that can contain several hooks and iterates through them at each event."""
@@ -108,62 +105,6 @@ def detach_hook(self, module):
             module = hook.detach_hook(module)
         return module
 
-    def reset_state(self, module):
-        for hook in self.hooks:
-            module = hook.reset_state(module)
-        return module
-
-
-class PyramidAttentionBroadcastHook(ModelHook):
-    def __init__(self, skip_callback: Callable[[torch.nn.Module], bool]) -> None:
-        super().__init__()
-
-        self.skip_callback = skip_callback
-
-        self.cache = None
-        self._iteration = 0
-
-    def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
-        args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
-
-        if self.cache is not None and self.skip_callback(module):
-            output = self.cache
-        else:
-            output = module._old_forward(*args, **kwargs)
-
-        return module._diffusers_hook.post_forward(module, output)
-
-    def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
-        self.cache = output
-        return output
-
-    def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
-        self.cache = None
-        self._iteration = 0
-        return module
-
-
-class LayerSkipHook(ModelHook):
-    def __init__(self, skip_: Callable[[torch.nn.Module], bool]) -> None:
-        super().__init__()
-
-        self.skip_callback = skip_
-
-    def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
-        args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
-
-        if self.skip_callback(module):
-            # We want to skip this layer, so we have to return the input of the current layer
-            # as output of the next layer. But at this point, we don't have information about
-            # the arguments required by next layer. Even if we did, order matters unless we
-            # always pass kwargs. But that is not the case usually with hidden_states, encoder_hidden_states,
-            # temb, etc. TODO(aryan): implement correctly later
-            output = None
-        else:
-            output = module._old_forward(*args, **kwargs)
-
-        return module._diffusers_hook.post_forward(module, output)
-
 
 def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False):
     r"""
```

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 0 additions & 4 deletions
```diff
@@ -1088,10 +1088,6 @@ def maybe_free_model_hooks(self):
         is a no-op. Make sure to add this function to the end of the `__call__` function of your pipeline so that it
         functions correctly when applying enable_model_cpu_offload.
         """
-
-        if hasattr(self, "_diffusers_hook"):
-            self._diffusers_hook.reset_state()
-
         if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0:
             # `enable_model_cpu_offload` has not be called, so silently do nothing
             return
```
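The docstring kept as context above states the contract: `maybe_free_model_hooks` runs at the end of a pipeline's `__call__` so that `enable_model_cpu_offload` keeps working across calls. A hedged usage sketch of that public API (the checkpoint id is illustrative):

```python
# Usage sketch: enable_model_cpu_offload installs the offload hooks that
# maybe_free_model_hooks() re-arms at the end of each __call__.
# The checkpoint id below is illustrative.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

# Each call re-offloads the models when it finishes, so repeated calls stay memory-friendly.
image = pipe("an astronaut riding a horse").images[0]
image = pipe("a corgi wearing sunglasses").images[0]
```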

src/diffusers/pipelines/pyramid_attention_broadcast_utils.py

Lines changed: 35 additions & 19 deletions
```diff
@@ -13,12 +13,12 @@
 # limitations under the License.
 
 from dataclasses import dataclass
-from typing import Callable, Optional, Protocol, Tuple
+from typing import Any, Callable, Optional, Tuple
 
 import torch.nn as nn
 
 from ..models.attention_processor import Attention
-from ..models.hooks import PyramidAttentionBroadcastHook, add_hook_to_module
+from ..models.hooks import ModelHook, add_hook_to_module
 from ..utils import logging
 from .pipeline_utils import DiffusionPipeline
 
@@ -28,7 +28,7 @@
 
 _ATTENTION_CLASSES = (Attention,)
 
-_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
+_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
 _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
 _CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
 
@@ -96,21 +96,15 @@ class PyramidAttentionBroadcastState:
 
     def __init__(self) -> None:
         self.iteration = 0
+        self.cache = None
+
+    def update_state(self, output: Any) -> None:
+        self.iteration += 1
+        self.cache = output
 
     def reset_state(self):
         self.iteration = 0
-
-
-class nnModulePAB(Protocol):
-    r"""
-    Type hint for a torch.nn.Module that contains a `_pyramid_attention_broadcast_state` attribute.
-
-    Attributes:
-        _pyramid_attention_broadcast_state (`PyramidAttentionBroadcastState`):
-            The state of Pyramid Attention Broadcast.
-    """
-
-    _pyramid_attention_broadcast_state: PyramidAttentionBroadcastState
+        self.cache = None
 
 
 def apply_pyramid_attention_broadcast(
@@ -247,14 +241,15 @@ def _apply_pyramid_attention_broadcast_on_attention_class(
         )
         return
 
-    def skip_callback(module: nnModulePAB) -> bool:
+    def skip_callback(module: nn.Module) -> bool:
         pab_state = module._pyramid_attention_broadcast_state
-        current_timestep = pipeline._current_timestep
-        is_within_timestep_range = timestep_skip_range[0] < current_timestep < timestep_skip_range[1]
+        if pab_state.cache is None:
+            return False
+
+        is_within_timestep_range = timestep_skip_range[0] < pipeline._current_timestep < timestep_skip_range[1]
 
         if is_within_timestep_range:
             should_compute_attention = pab_state.iteration > 0 and pab_state.iteration % block_skip_range == 0
-            pab_state.iteration += 1
             return not should_compute_attention
 
         # We are still not in the phase of inference where skipping attention is possible without minimal quality
@@ -263,3 +258,24 @@ def skip_callback(module: nnModulePAB) -> bool:
 
     logger.debug(f"Enabling Pyramid Attention Broadcast ({block_type}) in layer: {name}")
     apply_pyramid_attention_broadcast_on_module(module, skip_callback)
+
+
+class PyramidAttentionBroadcastHook(ModelHook):
+    def __init__(self, skip_callback: Callable[[nn.Module], bool]) -> None:
+        super().__init__()
+
+        self.skip_callback = skip_callback
+
+    def new_forward(self, module: nn.Module, *args, **kwargs) -> Any:
+        args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
+
+        if self.skip_callback(module):
+            output = module._pyramid_attention_broadcast_state.cache
+        else:
+            output = module._old_forward(*args, **kwargs)
+
+        return module._diffusers_hook.post_forward(module, output)
+
+    def post_forward(self, module: nn.Module, output: Any) -> Any:
+        module._pyramid_attention_broadcast_state.update_state(output)
+        return output
```
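Taken together, the moved pieces form a simple cache-and-replay loop: `PyramidAttentionBroadcastState` tracks an iteration counter and the last attention output, `skip_callback` decides whether that cached output may be reused for the current step, and `PyramidAttentionBroadcastHook` either replays the cache or recomputes and calls `update_state`. The standalone sketch below mirrors that loop with hypothetical names (`ToyBroadcastState`, `wrap_with_cache`) and without the timestep gating or the `_diffusers_hook` registry.

```python
# Standalone sketch (hypothetical names) of the cache-and-replay behaviour above:
# a per-module state stores the last output, and a skip callback decides when the
# wrapped forward may return that cached value instead of recomputing.
from typing import Any

import torch


class ToyBroadcastState:
    def __init__(self) -> None:
        self.iteration = 0
        self.cache = None

    def update_state(self, output: Any) -> None:
        self.iteration += 1
        self.cache = output


def wrap_with_cache(module: torch.nn.Module, block_skip_range: int = 2) -> torch.nn.Module:
    state = ToyBroadcastState()

    def skip_callback(_: torch.nn.Module) -> bool:
        if state.cache is None:
            return False  # nothing cached yet, must compute
        # Recompute every `block_skip_range`-th iteration, replay the cache otherwise.
        should_compute = state.iteration % block_skip_range == 0
        return not should_compute

    original_forward = module.forward

    def new_forward(*args, **kwargs):
        output = state.cache if skip_callback(module) else original_forward(*args, **kwargs)
        state.update_state(output)  # the real hook does this in post_forward
        return output

    module.forward = new_forward
    return module


layer = wrap_with_cache(torch.nn.Linear(8, 8))
x = torch.randn(1, 8)
outputs = [layer(x) for _ in range(4)]
print([o is outputs[0] for o in outputs])  # [True, True, False, False] -> call 2 replays call 1; call 4 replays call 3
```

In the actual PR the callback additionally checks `pipeline._current_timestep` against `timestep_skip_range`, and the state lives on the module as `_pyramid_attention_broadcast_state` rather than in a closure.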
