Skip to content

Commit 6090575

Browse files
committed
merge pyramid-attention-rewrite-2
1 parent d95d61a commit 6090575

14 files changed

+316
-209
lines changed

src/diffusers/models/hooks.py

Lines changed: 45 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import functools
16-
from typing import Any, Callable, Dict, Tuple, Union
16+
from typing import Any, Callable, Dict, Tuple
1717

1818
import torch
1919

@@ -117,45 +117,72 @@ def reset_state(self, module):
117117
class PyramidAttentionBroadcastHook(ModelHook):
118118
def __init__(
119119
self,
120-
skip_range: int,
121-
timestep_range: Tuple[int, int],
122-
timestep_callback: Callable[[], Union[torch.LongTensor, int]],
120+
skip_callback: Callable[[torch.nn.Module], bool],
121+
# skip_range: int,
122+
# timestep_range: Tuple[int, int],
123+
# timestep_callback: Callable[[], Union[torch.LongTensor, int]],
123124
) -> None:
124125
super().__init__()
125126

126-
self.skip_range = skip_range
127-
self.timestep_range = timestep_range
128-
self.timestep_callback = timestep_callback
127+
# self.skip_range = skip_range
128+
# self.timestep_range = timestep_range
129+
# self.timestep_callback = timestep_callback
130+
self.skip_callback = skip_callback
129131

130-
self.attention_cache = None
132+
self.cache = None
131133
self._iteration = 0
132134

133135
def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
134136
args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
135137

136-
current_timestep = self.timestep_callback()
137-
is_within_timestep_range = self.timestep_range[0] < current_timestep < self.timestep_range[1]
138-
should_compute_attention = self._iteration % self.skip_range == 0
138+
# current_timestep = self.timestep_callback()
139+
# is_within_timestep_range = self.timestep_range[0] < current_timestep < self.timestep_range[1]
140+
# should_compute_attention = self._iteration % self.skip_range == 0
139141

140-
if not is_within_timestep_range or should_compute_attention:
141-
output = module._old_forward(*args, **kwargs)
142-
else:
143-
output = self.attention_cache
142+
# if not is_within_timestep_range or should_compute_attention:
143+
# output = module._old_forward(*args, **kwargs)
144+
# else:
145+
# output = self.attention_cache
144146

145-
self._iteration = self._iteration + 1
147+
if self.cache is not None and self.skip_callback(module):
148+
output = self.cache
149+
else:
150+
output = module._old_forward(*args, **kwargs)
146151

147152
return module._diffusers_hook.post_forward(module, output)
148153

149154
def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
150-
self.attention_cache = output
155+
self.cache = output
151156
return output
152157

153158
def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
154-
self.attention_cache = None
159+
self.cache = None
155160
self._iteration = 0
156161
return module
157162

158163

164+
class LayerSkipHook(ModelHook):
165+
def __init__(self, skip_: Callable[[torch.nn.Module], bool]) -> None:
166+
super().__init__()
167+
168+
self.skip_callback = skip_
169+
170+
def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
171+
args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
172+
173+
if self.skip_callback(module):
174+
# We want to skip this layer, so we have to return the input of the current layer
175+
# as output of the next layer. But at this point, we don't have information about
176+
# the arguments required by next layer. Even if we did, order matters unless we
177+
# always pass kwargs. But that is not the case usually with hidden_states, encoder_hidden_states,
178+
# temb, etc. TODO(aryan): implement correctly later
179+
output = None
180+
else:
181+
output = module._old_forward(*args, **kwargs)
182+
183+
return module._diffusers_hook.post_forward(module, output)
184+
185+
159186
def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False):
160187
r"""
161188
Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove

src/diffusers/pipelines/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@
5858
"StableDiffusionMixin",
5959
"ImagePipelineOutput",
6060
]
61+
_import_structure["pyramid_attention_broadcast_utils"] = [
62+
"PyramidAttentionBroadcastConfig",
63+
"apply_pyramid_attention_broadcast",
64+
"apply_pyramid_attention_broadcast_on_module",
65+
]
6166
_import_structure["deprecated"].extend(
6267
[
6368
"PNDMPipeline",
@@ -447,6 +452,11 @@
447452
ImagePipelineOutput,
448453
StableDiffusionMixin,
449454
)
455+
from .pyramid_attention_broadcast_utils import (
456+
PyramidAttentionBroadcastConfig,
457+
apply_pyramid_attention_broadcast,
458+
apply_pyramid_attention_broadcast_on_module,
459+
)
450460

451461
try:
452462
if not (is_torch_available() and is_librosa_available()):

src/diffusers/pipelines/allegro/pipeline_allegro.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
)
3939
from ...utils.torch_utils import randn_tensor
4040
from ...video_processor import VideoProcessor
41-
from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin
4241
from .pipeline_output import AllegroPipelineOutput
4342

4443

@@ -132,7 +131,7 @@ def retrieve_timesteps(
132131
return timesteps, num_inference_steps
133132

134133

135-
class AllegroPipeline(DiffusionPipeline, PyramidAttentionBroadcastMixin):
134+
class AllegroPipeline(DiffusionPipeline):
136135
r"""
137136
Pipeline for text-to-video generation using Allegro.
138137

src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
from ...utils import logging, replace_example_docstring
3030
from ...utils.torch_utils import randn_tensor
3131
from ...video_processor import VideoProcessor
32-
from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin
3332
from .pipeline_output import CogVideoXPipelineOutput
3433

3534

@@ -138,7 +137,7 @@ def retrieve_timesteps(
138137
return timesteps, num_inference_steps
139138

140139

141-
class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin, PyramidAttentionBroadcastMixin):
140+
class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
142141
r"""
143142
Pipeline for text-to-video generation using CogVideoX.
144143

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
from ...utils import logging, replace_example_docstring
3131
from ...utils.torch_utils import randn_tensor
3232
from ...video_processor import VideoProcessor
33-
from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin
3433
from .pipeline_output import CogVideoXPipelineOutput
3534

3635

@@ -145,7 +144,7 @@ def retrieve_timesteps(
145144
return timesteps, num_inference_steps
146145

147146

148-
class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin, PyramidAttentionBroadcastMixin):
147+
class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
149148
r"""
150149
Pipeline for controlled text-to-video generation using CogVideoX Fun.
151150

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
)
3535
from ...utils.torch_utils import randn_tensor
3636
from ...video_processor import VideoProcessor
37-
from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin
3837
from .pipeline_output import CogVideoXPipelineOutput
3938

4039

@@ -154,7 +153,7 @@ def retrieve_latents(
154153
raise AttributeError("Could not access latents of provided encoder_output")
155154

156155

157-
class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin, PyramidAttentionBroadcastMixin):
156+
class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
158157
r"""
159158
Pipeline for image-to-video generation using CogVideoX.
160159

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
from ...utils import logging, replace_example_docstring
3131
from ...utils.torch_utils import randn_tensor
3232
from ...video_processor import VideoProcessor
33-
from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin
3433
from .pipeline_output import CogVideoXPipelineOutput
3534

3635

@@ -160,7 +159,7 @@ def retrieve_latents(
160159
raise AttributeError("Could not access latents of provided encoder_output")
161160

162161

163-
class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin, PyramidAttentionBroadcastMixin):
162+
class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
164163
r"""
165164
Pipeline for video-to-video generation using CogVideoX.
166165

src/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,7 @@ def __call__(
655655

656656
self._guidance_scale = guidance_scale
657657
self._joint_attention_kwargs = joint_attention_kwargs
658+
self._current_timestep = None
658659
self._interrupt = False
659660

660661
# 2. Define call parameters
@@ -731,6 +732,7 @@ def __call__(
731732
if self.interrupt:
732733
continue
733734

735+
self._current_timestep = t
734736
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
735737
timestep = t.expand(latents.shape[0]).to(latents.dtype)
736738

@@ -771,9 +773,10 @@ def __call__(
771773
if XLA_AVAILABLE:
772774
xm.mark_step()
773775

776+
self._current_timestep = None
777+
774778
if output_type == "latent":
775779
image = latents
776-
777780
else:
778781
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
779782
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor

src/diffusers/pipelines/latte/pipeline_latte.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
)
3838
from ...utils.torch_utils import is_compiled_module, randn_tensor
3939
from ...video_processor import VideoProcessor
40-
from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin
4140

4241

4342
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -133,7 +132,7 @@ class LattePipelineOutput(BaseOutput):
133132
frames: torch.Tensor
134133

135134

136-
class LattePipeline(DiffusionPipeline, PyramidAttentionBroadcastMixin):
135+
class LattePipeline(DiffusionPipeline):
137136
r"""
138137
Pipeline for text-to-video generation using Latte.
139138

0 commit comments

Comments
 (0)