
Commit 995b82f

update

1 parent d95d61a commit 995b82f

11 files changed: +141 additions, -33 deletions

src/diffusers/models/hooks.py

Lines changed: 44 additions & 17 deletions

@@ -117,45 +117,72 @@ def reset_state(self, module):
 class PyramidAttentionBroadcastHook(ModelHook):
     def __init__(
         self,
-        skip_range: int,
-        timestep_range: Tuple[int, int],
-        timestep_callback: Callable[[], Union[torch.LongTensor, int]],
+        skip_callback: Callable[[torch.nn.Module], bool],
+        # skip_range: int,
+        # timestep_range: Tuple[int, int],
+        # timestep_callback: Callable[[], Union[torch.LongTensor, int]],
     ) -> None:
         super().__init__()
 
-        self.skip_range = skip_range
-        self.timestep_range = timestep_range
-        self.timestep_callback = timestep_callback
+        # self.skip_range = skip_range
+        # self.timestep_range = timestep_range
+        # self.timestep_callback = timestep_callback
+        self.skip_callback = skip_callback
 
-        self.attention_cache = None
+        self.cache = None
         self._iteration = 0
 
     def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
         args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
 
-        current_timestep = self.timestep_callback()
-        is_within_timestep_range = self.timestep_range[0] < current_timestep < self.timestep_range[1]
-        should_compute_attention = self._iteration % self.skip_range == 0
+        # current_timestep = self.timestep_callback()
+        # is_within_timestep_range = self.timestep_range[0] < current_timestep < self.timestep_range[1]
+        # should_compute_attention = self._iteration % self.skip_range == 0
 
-        if not is_within_timestep_range or should_compute_attention:
-            output = module._old_forward(*args, **kwargs)
-        else:
-            output = self.attention_cache
+        # if not is_within_timestep_range or should_compute_attention:
+        #     output = module._old_forward(*args, **kwargs)
+        # else:
+        #     output = self.attention_cache
 
-        self._iteration = self._iteration + 1
+        if self.cache is not None and self.skip_callback(module):
+            output = self.cache
+        else:
+            output = module._old_forward(*args, **kwargs)
 
         return module._diffusers_hook.post_forward(module, output)
 
     def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
-        self.attention_cache = output
+        self.cache = output
         return output
 
     def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
-        self.attention_cache = None
+        self.cache = None
         self._iteration = 0
         return module
 
 
+class LayerSkipHook(ModelHook):
+    def __init__(self, skip_: Callable[[torch.nn.Module], bool]) -> None:
+        super().__init__()
+
+        self.skip_callback = skip_
+
+    def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
+        args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
+
+        if self.skip_callback(module):
+            # We want to skip this layer, so we have to return the input of the current layer
+            # as output of the next layer. But at this point, we don't have information about
+            # the arguments required by the next layer. Even if we did, order matters unless we
+            # always pass kwargs. That is usually not the case with hidden_states, encoder_hidden_states,
+            # temb, etc. TODO(aryan): implement correctly later
+            output = None
+        else:
+            output = module._old_forward(*args, **kwargs)
+
+        return module._diffusers_hook.post_forward(module, output)
+
+
 def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False):
     r"""
     Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove
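
The refactor above moves the skip decision out of the hook: instead of tracking timesteps and an internal counter, `PyramidAttentionBroadcastHook` now defers to a caller-supplied `skip_callback` and reuses its cached output whenever the callback returns `True`. Below is a minimal sketch of how such a callback might be wired up; the counter-based `_SkipEveryOther` helper and the `attn_module` variable are illustrative assumptions, not part of this commit, and the import assumes this branch's in-progress `hooks` module.

```python
import torch

# Assumes the in-progress module introduced on this branch; not available in released diffusers.
from diffusers.models.hooks import PyramidAttentionBroadcastHook, add_hook_to_module


class _SkipEveryOther:
    """Hypothetical callback: recompute attention on every `skip_range`-th call, reuse the cache otherwise."""

    def __init__(self, skip_range: int = 2) -> None:
        self.skip_range = skip_range
        self._count = 0

    def __call__(self, module: torch.nn.Module) -> bool:
        # Returning True tells the hook to serve its cached output instead of calling `_old_forward`.
        should_skip = self._count % self.skip_range != 0
        self._count += 1
        return should_skip


# `attn_module` would be an attention block found by iterating the denoiser's `named_modules()`:
# hook = PyramidAttentionBroadcastHook(skip_callback=_SkipEveryOther(skip_range=2))
# add_hook_to_module(attn_module, hook)
```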

src/diffusers/pipelines/allegro/pipeline_allegro.py

Lines changed: 1 addition & 2 deletions

@@ -38,7 +38,6 @@
 )
 from ...utils.torch_utils import randn_tensor
 from ...video_processor import VideoProcessor
-from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin
 from .pipeline_output import AllegroPipelineOutput
 
 
@@ -132,7 +131,7 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
-class AllegroPipeline(DiffusionPipeline, PyramidAttentionBroadcastMixin):
+class AllegroPipeline(DiffusionPipeline):
     r"""
     Pipeline for text-to-video generation using Allegro.

src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py

Lines changed: 1 addition & 2 deletions

@@ -29,7 +29,6 @@
 from ...utils import logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ...video_processor import VideoProcessor
-from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin
 from .pipeline_output import CogVideoXPipelineOutput
 
 
@@ -138,7 +137,7 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
-class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin, PyramidAttentionBroadcastMixin):
+class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
     r"""
     Pipeline for text-to-video generation using CogVideoX.

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py

Lines changed: 1 addition & 2 deletions

@@ -30,7 +30,6 @@
 from ...utils import logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ...video_processor import VideoProcessor
-from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin
 from .pipeline_output import CogVideoXPipelineOutput
 
 
@@ -145,7 +144,7 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
-class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin, PyramidAttentionBroadcastMixin):
+class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
     r"""
     Pipeline for controlled text-to-video generation using CogVideoX Fun.

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py

Lines changed: 1 addition & 2 deletions

@@ -34,7 +34,6 @@
 )
 from ...utils.torch_utils import randn_tensor
 from ...video_processor import VideoProcessor
-from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin
 from .pipeline_output import CogVideoXPipelineOutput
 
 
@@ -154,7 +153,7 @@ def retrieve_latents(
     raise AttributeError("Could not access latents of provided encoder_output")
 
 
-class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin, PyramidAttentionBroadcastMixin):
+class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
     r"""
     Pipeline for image-to-video generation using CogVideoX.

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py

Lines changed: 1 addition & 2 deletions

@@ -30,7 +30,6 @@
 from ...utils import logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ...video_processor import VideoProcessor
-from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin
 from .pipeline_output import CogVideoXPipelineOutput
 
 
@@ -160,7 +159,7 @@ def retrieve_latents(
     raise AttributeError("Could not access latents of provided encoder_output")
 
 
-class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin, PyramidAttentionBroadcastMixin):
+class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
     r"""
     Pipeline for video-to-video generation using CogVideoX.

src/diffusers/pipelines/latte/pipeline_latte.py

Lines changed: 1 addition & 2 deletions

@@ -37,7 +37,6 @@
 )
 from ...utils.torch_utils import is_compiled_module, randn_tensor
 from ...video_processor import VideoProcessor
-from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -133,7 +132,7 @@ class LattePipelineOutput(BaseOutput):
     frames: torch.Tensor
 
 
-class LattePipeline(DiffusionPipeline, PyramidAttentionBroadcastMixin):
+class LattePipeline(DiffusionPipeline):
     r"""
     Pipeline for text-to-video generation using Latte.

src/diffusers/pipelines/pyramid_broadcast_utils.py renamed to src/diffusers/pipelines/pyramid_attention_broadcast_utils.py

Lines changed: 88 additions & 1 deletion

@@ -12,18 +12,105 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional, Tuple
+from dataclasses import dataclass
+from typing import Callable, List, Optional, Tuple, Type, TypeVar
 
 import torch.nn as nn
 
 from ..models.attention_processor import Attention
 from ..models.hooks import PyramidAttentionBroadcastHook, add_hook_to_module, remove_hook_from_module
 from ..utils import logging
+from .pipeline_utils import DiffusionPipeline
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+_ATTENTION_CLASSES = (Attention,)
+
+_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ["blocks", "transformer_blocks"]
+_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ["temporal_transformer_blocks"]
+_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ["blocks", "transformer_blocks"]
+
+
+@dataclass
+class PyramidAttentionBroadcastConfig:
+    spatial_attention_block_skip = None
+    temporal_attention_block_skip = None
+    cross_attention_block_skip = None
+
+    spatial_attention_timestep_skip_range = (100, 800)
+    temporal_attention_timestep_skip_range = (100, 800)
+    cross_attention_timestep_skip_range = (100, 800)
+
+    spatial_attention_block_identifiers = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
+    temporal_attention_block_identifiers = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
+    cross_attention_block_identifiers = _CROSS_ATTENTION_BLOCK_IDENTIFIERS
+
+
+class PyramidAttentionBroadcastState:
+    iteration = 0
+
+
+def apply_pyramid_attention_broadcast(
+    pipeline: DiffusionPipeline,
+    config: Optional[PyramidAttentionBroadcastConfig] = None,
+    denoiser: Optional[nn.Module] = None,
+):
+    if config is None:
+        config = PyramidAttentionBroadcastConfig()
+
+    if config.spatial_attention_block_skip is None and config.temporal_attention_block_skip is None and config.cross_attention_block_skip is None:
+        logger.warning(
+            "Pyramid Attention Broadcast requires one or more of `spatial_attention_block_skip`, `temporal_attention_block_skip` "
+            "or `cross_attention_block_skip` parameters to be set to an integer, not `None`. Defaulting to using `spatial_attention_block_skip=2`. "
+            "To avoid this warning, please set one of the above parameters."
+        )
+        config.spatial_attention_block_skip = 2
+
+    if denoiser is None:
+        denoiser = pipeline.transformer if hasattr(pipeline, "transformer") else pipeline.unet
+
+    for name, module in denoiser.named_modules():
+        if not isinstance(module, _ATTENTION_CLASSES):
+            continue
+        if isinstance(module, Attention):
+            _apply_pyramid_attention_broadcast_on_attention_class(pipeline, name, module, config)
+
+
+# def apply_pyramid_attention_broadcast_spatial(module: TypeVar[_ATTENTION_CLASSES], config: PyramidAttentionBroadcastConfig):
+#     hook = PyramidAttentionBroadcastHook(skip_callback=)
+#     add_hook_to_module(module)
+
+
+def _apply_pyramid_attention_broadcast_on_attention_class(pipeline: DiffusionPipeline, name: str, module: Attention, config: PyramidAttentionBroadcastConfig):
+    # Similar check as PEFT to determine if a string layer name matches a module name
+    is_spatial_self_attention = (
+        any(f"{identifier}." in name or identifier == name for identifier in config.spatial_attention_block_identifiers)
+        and config.spatial_attention_timestep_skip_range is not None
+        and not module.is_cross_attention
+    )
+    is_temporal_self_attention = (
+        any(f"{identifier}." in name or identifier == name for identifier in config.temporal_attention_block_identifiers)
+        and config.temporal_attention_timestep_skip_range is not None
+        and not module.is_cross_attention
+    )
+    is_cross_attention = (
+        any(f"{identifier}." in name or identifier == name for identifier in config.cross_attention_block_identifiers)
+        and config.cross_attention_timestep_skip_range is not None
+        and not module.is_cross_attention
+    )
+
+    if is_spatial_self_attention:
+        apply_pyramid_attention_broadcast_spatial(module, config)
+    elif is_temporal_self_attention:
+        apply_pyramid_attention_broadcast_temporal(module, config)
+    elif is_cross_attention:
+        apply_pyramid_attention_broadcast_cross(module, config)
+    else:
+        logger.warning(f"Unable to apply Pyramid Attention Broadcast to the selected layer: {name}.")
+
+
 class PyramidAttentionBroadcastMixin:
     r"""Mixin class for [Pyramid Attention Broadcast](https://www.arxiv.org/abs/2408.12588)."""
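
For reference, the renamed module above sketches a functional entry point, `apply_pyramid_attention_broadcast`, which is intended to replace the pipeline-level mixin being removed from the pipelines in this commit. The snippet below is a hedged usage sketch assuming the API lands as written here; the per-type helpers (`apply_pyramid_attention_broadcast_spatial`, etc.) are still commented out at this point in the PR, so this is illustrative rather than end-to-end runnable, and the checkpoint id is only an example.

```python
import torch

from diffusers import CogVideoXPipeline

# Assumes the renamed module from this commit; the API is still in flux at this point in the PR.
from diffusers.pipelines.pyramid_attention_broadcast_utils import (
    PyramidAttentionBroadcastConfig,
    apply_pyramid_attention_broadcast,
)

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")

# The config fields are plain class attributes in this commit, so set them after construction.
config = PyramidAttentionBroadcastConfig()
config.spatial_attention_block_skip = 2  # recompute spatial attention every 2nd call, reuse the cache otherwise
config.spatial_attention_timestep_skip_range = (100, 800)  # intended timestep window for broadcasting

# `denoiser` defaults to `pipe.transformer` (falling back to `pipe.unet`) when not passed explicitly.
apply_pyramid_attention_broadcast(pipe, config)
```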

tests/pipelines/cogvideo/test_cogvideox.py

Lines changed: 1 addition & 1 deletion

@@ -21,7 +21,7 @@
 from transformers import AutoTokenizer, T5EncoderModel
 
 from diffusers import AutoencoderKLCogVideoX, CogVideoXPipeline, CogVideoXTransformer3DModel, DDIMScheduler
-from diffusers.pipelines.pyramid_broadcast_utils import PyramidAttentionBroadcastAttentionProcessorWrapper
+from diffusers.pipelines.pyramid_attention_broadcast_utils import PyramidAttentionBroadcastAttentionProcessorWrapper
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
     numpy_cosine_similarity_distance,

tests/pipelines/cogvideo/test_cogvideox_image2video.py

Lines changed: 1 addition & 1 deletion

@@ -22,7 +22,7 @@
 from transformers import AutoTokenizer, T5EncoderModel
 
 from diffusers import AutoencoderKLCogVideoX, CogVideoXImageToVideoPipeline, CogVideoXTransformer3DModel, DDIMScheduler
-from diffusers.pipelines.pyramid_broadcast_utils import PyramidAttentionBroadcastAttentionProcessorWrapper
+from diffusers.pipelines.pyramid_attention_broadcast_utils import PyramidAttentionBroadcastAttentionProcessorWrapper
 from diffusers.utils import load_image
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
