Commit 07edfa9

update

1 parent 30d9aaf

File tree

4 files changed: +77 additions, -62 deletions

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -276,6 +276,7 @@
             "CogVideoXVideoToVideoPipeline",
             "CogView3PlusPipeline",
             "CycleDiffusionPipeline",
+            "FasterCacheConfig",
             "FluxControlImg2ImgPipeline",
             "FluxControlInpaintPipeline",
             "FluxControlNetImg2ImgPipeline",
@@ -422,6 +423,7 @@
             "WuerstchenCombinedPipeline",
             "WuerstchenDecoderPipeline",
             "WuerstchenPriorPipeline",
+            "apply_faster_cache",
         ]
     )
 
@@ -765,6 +767,7 @@
         CogVideoXVideoToVideoPipeline,
         CogView3PlusPipeline,
         CycleDiffusionPipeline,
+        FasterCacheConfig,
         FluxControlImg2ImgPipeline,
         FluxControlInpaintPipeline,
         FluxControlNetImg2ImgPipeline,
@@ -909,6 +912,7 @@
         WuerstchenCombinedPipeline,
         WuerstchenDecoderPipeline,
         WuerstchenPriorPipeline,
+        apply_faster_cache,
     )
 
     try:
```
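These four hunks only register the two new public names in diffusers' lazy top-level namespace. A quick sanity check of the re-exports, assuming a build of this branch is installed:

```python
# Sanity check that the new top-level re-exports resolve (assumes a build of
# this branch is installed; both names are added to `_import_structure` and
# the type-checking imports in the hunks above).
from diffusers import FasterCacheConfig, apply_faster_cache

print(FasterCacheConfig)
print(apply_faster_cache)
```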

src/diffusers/models/embeddings.py

Lines changed: 3 additions & 1 deletion

```diff
@@ -334,7 +334,9 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
             " `from_numpy` is no longer required."
             " Pass `output_type='pt' to use the new version now."
         )
-        # deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+        # TODO: Needs to be handled or errors out. Updated to 0.34.0 so that the benchmark code
+        # runs without issues, but this should be handled properly before merge.
+        deprecate("output_type=='np'", "0.34.0", deprecation_message, standard_warn=False)
         return get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos)
     if embed_dim % 2 != 0:
         raise ValueError("embed_dim must be divisible by 2")
```
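For context on the TODO, a minimal sketch of the behavior this version bump works around, assuming diffusers' `deprecate` helper follows its usual contract (warn while the installed version is below the removal version, raise once it is at or past it; version numbers here are illustrative):

```python
# Minimal sketch of the deprecate contract assumed above, not the real helper.
import warnings

from packaging import version


def deprecate_sketch(removal_version: str, message: str, installed_version: str) -> None:
    if version.parse(installed_version) >= version.parse(removal_version):
        # At or past the removal version the deprecated path must be gone,
        # so the call itself becomes an error.
        raise ValueError(f"Deprecated since {removal_version}: {message}")
    warnings.warn(message, FutureWarning)


# Targeting "0.33.0" from a 0.33.x dev install raises immediately, which is
# why the previously commented-out call errored; "0.34.0" only warns for now.
deprecate_sketch("0.34.0", "Pass output_type='pt' to use the new version now.", "0.33.0.dev0")
```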

src/diffusers/pipelines/__init__.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -58,6 +58,7 @@
         "StableDiffusionMixin",
         "ImagePipelineOutput",
     ]
+    _import_structure["faster_cache_utils"] = ["FasterCacheConfig", "apply_faster_cache"]
     _import_structure["deprecated"].extend(
         [
             "PNDMPipeline",
@@ -449,6 +450,7 @@
     from .ddpm import DDPMPipeline
     from .deprecated import KarrasVePipeline, LDMPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline
     from .dit import DiTPipeline
+    from .faster_cache_utils import FasterCacheConfig, apply_faster_cache
     from .latent_diffusion import LDMSuperResolutionPipeline
     from .pipeline_utils import (
         AudioPipelineOutput,
```
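The two hunks are the two halves of diffusers' lazy-import pattern: the `_import_structure` entry drives deferred runtime imports, while the direct `from .faster_cache_utils import ...` is what static type checkers see. A simplified sketch of the idea (diffusers' actual implementation is its `_LazyModule` helper; this is only an illustration):

```python
# Simplified illustration of the lazy-import pattern, as it would sit in a
# package __init__.py. Not diffusers' real _LazyModule implementation.
import importlib
from typing import TYPE_CHECKING

_import_structure = {"faster_cache_utils": ["FasterCacheConfig", "apply_faster_cache"]}

if TYPE_CHECKING:
    # Type checkers resolve the real symbols eagerly...
    from .faster_cache_utils import FasterCacheConfig, apply_faster_cache
else:
    # ...while at runtime, attribute access triggers the submodule import lazily (PEP 562).
    def __getattr__(name):
        for submodule, names in _import_structure.items():
            if name in names:
                module = importlib.import_module(f".{submodule}", __name__)
                return getattr(module, name)
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```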

src/diffusers/pipelines/faster_cache_utils.py

Lines changed: 68 additions & 61 deletions

````diff
@@ -29,6 +29,7 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+# TODO(aryan): handle mochi attention
 _ATTENTION_CLASSES = (Attention,)
 
 _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
@@ -63,73 +64,75 @@ class FasterCacheConfig:
             states again.
         spatial_attention_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 681)`):
             The timestep range within which the spatial attention computation can be skipped without a significant loss
-            in quality. This is to be determined by the user based on the underlying model. The first value in the tuple
-            is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for denoising are
-            in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at timestep 0). For the
-            default values, this would mean that the spatial attention computation skipping will be applicable only
-            after denoising timestep 681 is reached, and continue until the end of the denoising process.
+            in quality. This is to be determined by the user based on the underlying model. The first value in the
+            tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for
+            denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at
+            timestep 0). For the default values, this would mean that the spatial attention computation skipping will
+            be applicable only after denoising timestep 681 is reached, and continue until the end of the denoising
+            process.
         temporal_attention_timestep_skip_range (`Tuple[float, float]`, *optional*, defaults to `None`):
-            The timestep range within which the temporal attention computation can be skipped without a significant loss
-            in quality. This is to be determined by the user based on the underlying model. The first value in the tuple
-            is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for denoising are
-            in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at timestep 0).
+            The timestep range within which the temporal attention computation can be skipped without a significant
+            loss in quality. This is to be determined by the user based on the underlying model. The first value in the
+            tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for
+            denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at
+            timestep 0).
         low_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(99, 901)`):
-            The timestep range within which the low frequency weight scaling update is applied. The first value in the tuple
-            is the lower bound and the second value is the upper bound of the timestep range. The callback function for
-            the update is called only within this range.
+            The timestep range within which the low frequency weight scaling update is applied. The first value in the
+            tuple is the lower bound and the second value is the upper bound of the timestep range. The callback
+            function for the update is called only within this range.
         high_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(-1, 301)`):
-            The timestep range within which the high frequency weight scaling update is applied. The first value in the tuple
-            is the lower bound and the second value is the upper bound of the timestep range. The callback function for
-            the update is called only within this range.
+            The timestep range within which the high frequency weight scaling update is applied. The first value in the
+            tuple is the lower bound and the second value is the upper bound of the timestep range. The callback
+            function for the update is called only within this range.
         alpha_low_frequency (`float`, defaults to `1.1`):
             The weight to scale the low frequency updates by. This is used to approximate the unconditional branch from
             the conditional branch outputs.
         alpha_high_frequency (`float`, defaults to `1.1`):
-            The weight to scale the high frequency updates by. This is used to approximate the unconditional branch from
-            the conditional branch outputs.
+            The weight to scale the high frequency updates by. This is used to approximate the unconditional branch
+            from the conditional branch outputs.
         unconditional_batch_skip_range (`int`, defaults to `5`):
             Process the unconditional branch every `N` iterations. If this is set to `N`, the unconditional branch
             computation will be skipped `N - 1` times (i.e., cached unconditional branch states will be re-used) before
             computing the new unconditional branch states again.
         unconditional_batch_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 641)`):
-            The timestep range within which the unconditional branch computation can be skipped without a significant loss
-            in quality. This is to be determined by the user based on the underlying model. The first value in the tuple
-            is the lower bound and the second value is the upper bound.
+            The timestep range within which the unconditional branch computation can be skipped without a significant
+            loss in quality. This is to be determined by the user based on the underlying model. The first value in the
+            tuple is the lower bound and the second value is the upper bound.
         spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks.*attn1", "transformer_blocks.*attn1", "single_transformer_blocks.*attn1")`):
-            The identifiers to match the spatial attention blocks in the model. If the name of the block contains any of
-            these identifiers, FasterCache will be applied to that block. This can either be the full layer names, partial
-            layer names, or regex patterns. Matching will always be done using a regex match.
+            The identifiers to match the spatial attention blocks in the model. If the name of the block contains any
+            of these identifiers, FasterCache will be applied to that block. This can either be the full layer names,
+            partial layer names, or regex patterns. Matching will always be done using a regex match.
         temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks.*attn1",)`):
-            The identifiers to match the temporal attention blocks in the model. If the name of the block contains any of
-            these identifiers, FasterCache will be applied to that block. This can either be the full layer names, partial
-            layer names, or regex patterns. Matching will always be done using a regex match.
+            The identifiers to match the temporal attention blocks in the model. If the name of the block contains any
+            of these identifiers, FasterCache will be applied to that block. This can either be the full layer names,
+            partial layer names, or regex patterns. Matching will always be done using a regex match.
         attention_weight_callback (`Callable[[nn.Module], float]`, defaults to `None`):
-            The callback function to determine the weight to scale the attention outputs by. This function should take the
-            attention module as input and return a float value. This is used to approximate the unconditional branch from
-            the conditional branch outputs. If not provided, the default weight is 0.5 for all timesteps. Typically, as
-            described in the paper, this weight should gradually increase from 0 to 1 as the inference progresses. Users
-            are encouraged to experiment and provide custom weight schedules that take into account the number of inference
-            steps and underlying model behaviour as denoising progresses.
+            The callback function to determine the weight to scale the attention outputs by. This function should take
+            the attention module as input and return a float value. This is used to approximate the unconditional
+            branch from the conditional branch outputs. If not provided, the default weight is 0.5 for all timesteps.
+            Typically, as described in the paper, this weight should gradually increase from 0 to 1 as the inference
+            progresses. Users are encouraged to experiment and provide custom weight schedules that take into account
+            the number of inference steps and underlying model behaviour as denoising progresses.
         low_frequency_weight_callback (`Callable[[nn.Module], float]`, defaults to `None`):
-            The callback function to determine the weight to scale the low frequency updates by. If not provided, the default
-            weight is 1.1 for timesteps within the range specified (as described in the paper).
+            The callback function to determine the weight to scale the low frequency updates by. If not provided, the
+            default weight is 1.1 for timesteps within the range specified (as described in the paper).
         high_frequency_weight_callback (`Callable[[nn.Module], float]`, defaults to `None`):
-            The callback function to determine the weight to scale the high frequency updates by. If not provided, the default
-            weight is 1.1 for timesteps within the range specified (as described in the paper).
+            The callback function to determine the weight to scale the high frequency updates by. If not provided, the
+            default weight is 1.1 for timesteps within the range specified (as described in the paper).
         tensor_format (`str`, defaults to `"BCFHW"`):
-            The format of the input tensors. This should be one of `"BCFHW"`, `"BFCHW"`, or `"BCHW"`. The format is used to
-            split individual latent frames in order for low and high frequency components to be computed.
+            The format of the input tensors. This should be one of `"BCFHW"`, `"BFCHW"`, or `"BCHW"`. The format is
+            used to split individual latent frames in order for low and high frequency components to be computed.
         _unconditional_conditional_input_kwargs_identifiers (`List[str]`, defaults to `("hidden_states", "encoder_hidden_states", "timestep", "attention_mask", "encoder_attention_mask")`):
-            The identifiers to match the input kwargs that contain the batchwise-concatenated unconditional and conditional
-            inputs. If the name of the input kwargs contains any of these identifiers, FasterCache will split the inputs
-            into unconditional and conditional branches. This must be a list of exact input kwargs names that contain the
-            batchwise-concatenated unconditional and conditional inputs.
+            The identifiers to match the input kwargs that contain the batchwise-concatenated unconditional and
+            conditional inputs. If the name of the input kwargs contains any of these identifiers, FasterCache will
+            split the inputs into unconditional and conditional branches. This must be a list of exact input kwargs
+            names that contain the batchwise-concatenated unconditional and conditional inputs.
         _guidance_distillation_kwargs_identifiers (`List[str]`, defaults to `("guidance",)`):
-            The identifiers to match the input kwargs that contain the guidance distillation inputs. If the name of the input
-            kwargs contains any of these identifiers, FasterCache will not split the inputs into unconditional and conditional
-            branches (unconditional branches are only computed sometimes based on certain checks). This allows usage of
-            FasterCache in models like Flux-Dev and HunyuanVideo which are guidance-distilled (only attention skipping
-            related parts are applied, and not unconditional branch approximation).
+            The identifiers to match the input kwargs that contain the guidance distillation inputs. If the name of the
+            input kwargs contains any of these identifiers, FasterCache will not split the inputs into unconditional
+            and conditional branches (unconditional branches are only computed sometimes based on certain checks). This
+            allows usage of FasterCache in models like Flux-Dev and HunyuanVideo which are guidance-distilled (only
+            attention skipping related parts are applied, and not unconditional branch approximation).
     """
 
     # In the paper and codebase, they hardcode these values to 2. However, it can be made configurable
@@ -225,7 +228,6 @@ def reset(self):
 def apply_faster_cache(
     pipeline: DiffusionPipeline,
     config: Optional[FasterCacheConfig] = None,
-    denoiser: Optional[nn.Module] = None,
 ) -> None:
     r"""
     Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline.
@@ -238,9 +240,6 @@ def apply_faster_cache(
             The diffusion pipeline to apply FasterCache to.
         config (`Optional[FasterCacheConfig]`, `optional`, defaults to `None`):
             The configuration to use for FasterCache.
-        denoiser (`Optional[nn.Module]`, `optional`, defaults to `None`):
-            The denoiser module to apply FasterCache to. If `None`, the pipeline's transformer or unet module will be
-            used.
 
     Example:
     ```python
@@ -310,8 +309,7 @@ def high_frequency_weight_callback(module: nn.Module) -> float:
     if config.tensor_format not in supported_tensor_formats:
         raise ValueError(f"`tensor_format` must be one of {supported_tensor_formats}, but got {config.tensor_format}.")
 
-    if denoiser is None:
-        denoiser = pipeline.transformer if hasattr(pipeline, "transformer") else pipeline.unet
+    denoiser = pipeline.transformer if hasattr(pipeline, "transformer") else pipeline.unet
     _apply_fastercache_on_denoiser(pipeline, denoiser, config)
 
     for name, module in denoiser.named_modules():
@@ -344,7 +342,11 @@ def uncond_skip_callback(module: nn.Module) -> bool:
     denoiser._fastercache_state = FasterCacheDenoiserState(
         config.low_frequency_weight_callback, config.high_frequency_weight_callback, uncond_skip_callback
     )
-    hook = FasterCacheDenoiserHook(config._unconditional_conditional_input_kwargs_identifiers, config._guidance_distillation_kwargs_identifiers, config.tensor_format)
+    hook = FasterCacheDenoiserHook(
+        config._unconditional_conditional_input_kwargs_identifiers,
+        config._guidance_distillation_kwargs_identifiers,
+        config.tensor_format,
+    )
     add_hook_to_module(denoiser, hook, append=True)
 
 
@@ -408,11 +410,12 @@ def skip_callback(module: nn.Module) -> bool:
 class FasterCacheDenoiserHook(ModelHook):
     _is_stateful = True
 
-    def __init__(self,
-        uncond_cond_input_kwargs_identifiers: List[str],
-        guidance_distillation_kwargs_identifiers: List[str],
-        tensor_format: str
-    ) -> None:
+    def __init__(
+        self,
+        uncond_cond_input_kwargs_identifiers: List[str],
+        guidance_distillation_kwargs_identifiers: List[str],
+        tensor_format: str,
+    ) -> None:
         super().__init__()
 
         # We can't easily detect what args are to be split in unconditional and conditional branches. We
@@ -451,7 +454,9 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any:
         # Make all children FasterCacheBlockHooks aware of whether the model is guidance distilled or not
         # because we cannot determine this within the block hooks
         for name, child_module in module.named_modules():
-            if hasattr(child_module, "_fastercache_state") and isinstance(child_module._fastercache_state, FasterCacheBlockState):
+            if hasattr(child_module, "_fastercache_state") and isinstance(
+                child_module._fastercache_state, FasterCacheBlockState
+            ):
                 # TODO(aryan): remove later
                 logger.debug(f"Setting guidance distillation flag for layer: {name}")
                 child_module._fastercache_state.is_guidance_distilled = state.is_guidance_distilled
@@ -570,7 +575,9 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any:
         # the cache (which only caches conditional branch outputs). So, if state.batch_size (which is the true
         # unconditional-conditional batch size) is same as the current batch size, we don't perform the layer
         # skip. Otherwise, we conditionally skip the layer based on what state.skip_callback returns.
-        should_skip_attention = state.skip_callback(module) and (state.is_guidance_distilled or state.batch_size != batch_size)
+        should_skip_attention = state.skip_callback(module) and (
+            state.is_guidance_distilled or state.batch_size != batch_size
+        )
 
         if should_skip_attention:
             # TODO(aryan): remove later
````
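Since this commit drops the `denoiser` argument, `apply_faster_cache` now resolves the denoiser from the pipeline itself (`pipeline.transformer` if present, else `pipeline.unet`), so a call site only needs a pipeline and a config. A minimal usage sketch; the checkpoint and config values below are illustrative assumptions, not taken from this commit:

```python
import torch

from diffusers import CogVideoXPipeline, FasterCacheConfig, apply_faster_cache

# Illustrative checkpoint; any pipeline with a `transformer` or `unet` should work.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Field names come from the FasterCacheConfig docstring above; the values are
# the documented defaults and should be tuned per model.
config = FasterCacheConfig(
    spatial_attention_timestep_skip_range=(-1, 681),  # skip spatial attn below timestep 681
    unconditional_batch_skip_range=5,  # recompute the uncond branch every 5 iterations
    tensor_format="BCFHW",  # batch, channels, frames, height, width
)

# The denoiser (here pipe.transformer) is resolved inside the call.
apply_faster_cache(pipe, config)

video = pipe("A cat playing piano", num_inference_steps=50).frames[0]
```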
