Commit 04874a7

reorder

1 parent 4996dfd, commit 04874a7

File tree

1 file changed: +180 -180 lines

src/diffusers/pipelines/fastercache_utils.py

180 additions, 180 deletions
@@ -225,186 +225,6 @@ def reset(self):
         self.is_guidance_distilled = None
 
 
-def apply_fastercache(
-    pipeline: DiffusionPipeline,
-    config: Optional[FasterCacheConfig] = None,
-) -> None:
-    r"""
-    Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline.
-
-    Args:
-        pipeline (`DiffusionPipeline`):
-            The diffusion pipeline to apply FasterCache to.
-        config (`Optional[FasterCacheConfig]`, `optional`, defaults to `None`):
-            The configuration to use for FasterCache.
-
-    Example:
-    ```python
-    >>> import torch
-    >>> from diffusers import CogVideoXPipeline, FasterCacheConfig, apply_fastercache
-
-    >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
-    >>> pipe.to("cuda")
-
-    >>> config = FasterCacheConfig(
-    ...     spatial_attention_block_skip_range=2,
-    ...     spatial_attention_timestep_skip_range=(-1, 681),
-    ...     low_frequency_weight_update_timestep_range=(99, 641),
-    ...     high_frequency_weight_update_timestep_range=(-1, 301),
-    ...     spatial_attention_block_identifiers=["transformer_blocks"],
-    ...     attention_weight_callback=lambda _: 0.3,
-    ...     tensor_format="BFCHW",
-    ... )
-    >>> apply_fastercache(pipe, config)
-    ```
-    """
-
-    if config is None:
-        logger.warning("No FasterCacheConfig provided. Using default configuration.")
-        config = FasterCacheConfig()
-
-    if config.attention_weight_callback is None:
-        # If the user has not provided a weight callback, we default to 0.5 for all timesteps.
-        # The paper recommends a weight that gradually increases from 0 to 1 as inference progresses, but this
-        # varies from model to model. Users must provide their own callback if they want a different weighting
-        # function. Defaulting to 0.5 works well in practice for most cases.
-        logger.warning(
-            "No `attention_weight_callback` provided when enabling FasterCache. Defaulting to using a weight of 0.5 for all timesteps."
-        )
-        config.attention_weight_callback = lambda _: 0.5
-
-    if config.low_frequency_weight_callback is None:
-        logger.debug(
-            "Low frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper."
-        )
-
-        def low_frequency_weight_callback(module: nn.Module) -> float:
-            is_within_range = (
-                config.low_frequency_weight_update_timestep_range[0]
-                < pipeline._current_timestep
-                < config.low_frequency_weight_update_timestep_range[1]
-            )
-            return config.alpha_low_frequency if is_within_range else 1.0
-
-        config.low_frequency_weight_callback = low_frequency_weight_callback
-
-    if config.high_frequency_weight_callback is None:
-        logger.debug(
-            "High frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper."
-        )
-
-        def high_frequency_weight_callback(module: nn.Module) -> float:
-            is_within_range = (
-                config.high_frequency_weight_update_timestep_range[0]
-                < pipeline._current_timestep
-                < config.high_frequency_weight_update_timestep_range[1]
-            )
-            return config.alpha_high_frequency if is_within_range else 1.0
-
-        config.high_frequency_weight_callback = high_frequency_weight_callback
-
-    supported_tensor_formats = ["BCFHW", "BFCHW", "BCHW"]  # TODO(aryan): Support BSC for LTX Video
-    if config.tensor_format not in supported_tensor_formats:
-        raise ValueError(f"`tensor_format` must be one of {supported_tensor_formats}, but got {config.tensor_format}.")
-
-    denoiser = pipeline.transformer if hasattr(pipeline, "transformer") else pipeline.unet
-    _apply_fastercache_on_denoiser(pipeline, denoiser, config)
-
-    for name, module in denoiser.named_modules():
-        if not isinstance(module, _ATTENTION_CLASSES):
-            continue
-        if isinstance(module, Attention):
-            _apply_fastercache_on_attention_class(pipeline, name, module, config)
-
-
-def _apply_fastercache_on_denoiser(
-    pipeline: DiffusionPipeline, denoiser: nn.Module, config: FasterCacheConfig
-) -> None:
-    def uncond_skip_callback(module: nn.Module) -> bool:
-        # We skip the unconditional branch only if the following conditions are met:
-        # 1. We have completed at least one iteration of the denoiser
-        # 2. The current timestep is within the range specified by the user. This is the optimal timestep range
-        #    where approximating the unconditional branch from the computation of the conditional branch is possible
-        #    without a significant loss in quality.
-        # 3. The current iteration is not a multiple of the unconditional batch skip range. This is done so that
-        #    we compute the unconditional branch at least once every few iterations to ensure minimal quality loss.
-
-        state: FasterCacheDenoiserState = module._fastercache_state
-        is_within_range = (
-            config.unconditional_batch_timestep_skip_range[0]
-            < pipeline._current_timestep
-            < config.unconditional_batch_timestep_skip_range[1]
-        )
-        return state.iteration > 0 and is_within_range and state.iteration % config.unconditional_batch_skip_range != 0
-
-    denoiser._fastercache_state = FasterCacheDenoiserState(
-        config.low_frequency_weight_callback, config.high_frequency_weight_callback, uncond_skip_callback
-    )
-    hook = FasterCacheDenoiserHook(
-        config._unconditional_conditional_input_kwargs_identifiers,
-        config._guidance_distillation_kwargs_identifiers,
-        config.tensor_format,
-    )
-    add_hook_to_module(denoiser, hook, append=True)
-
-
-def _apply_fastercache_on_attention_class(
-    pipeline: DiffusionPipeline, name: str, module: Attention, config: FasterCacheConfig
-) -> None:
-    is_spatial_self_attention = (
-        any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
-        and config.spatial_attention_block_skip_range is not None
-        and not module.is_cross_attention
-    )
-    is_temporal_self_attention = (
-        any(
-            f"{identifier}." in name or identifier == name
-            for identifier in config.temporal_attention_block_identifiers
-        )
-        and config.temporal_attention_block_skip_range is not None
-        and not module.is_cross_attention
-    )
-
-    block_skip_range, timestep_skip_range, block_type = None, None, None
-    if is_spatial_self_attention:
-        block_skip_range = config.spatial_attention_block_skip_range
-        timestep_skip_range = config.spatial_attention_timestep_skip_range
-        block_type = "spatial"
-    elif is_temporal_self_attention:
-        block_skip_range = config.temporal_attention_block_skip_range
-        timestep_skip_range = config.temporal_attention_timestep_skip_range
-        block_type = "temporal"
-
-    if block_skip_range is None or timestep_skip_range is None:
-        logger.debug(
-            f'Unable to apply FasterCache to the selected layer: "{name}" because it does '
-            f"not match any of the required criteria for spatial or temporal attention layers. Note, "
-            f"however, that this layer may still be valid for applying PAB. Please specify the correct "
-            f"block identifiers in the configuration or use the specialized `apply_fastercache_on_module` "
-            f"function to apply FasterCache to this layer."
-        )
-        return
-
-    def skip_callback(module: nn.Module) -> bool:
-        fastercache_state: FasterCacheBlockState = module._fastercache_state
-        is_within_timestep_range = timestep_skip_range[0] < pipeline._current_timestep < timestep_skip_range[1]
-
-        if not is_within_timestep_range:
-            # We are not yet in the phase of inference where attention can be skipped with minimal quality
-            # loss, as described in the paper. So, the attention computation cannot be skipped.
-            return False
-
-        should_compute_attention = (
-            fastercache_state.iteration > 0 and fastercache_state.iteration % block_skip_range == 0
-        )
-        return not should_compute_attention
-
-    logger.debug(f"Enabling FasterCache ({block_type}) for layer: {name}")
-    module._fastercache_state = FasterCacheBlockState(skip_callback, config.attention_weight_callback)
-    hook = FasterCacheBlockHook()
-    add_hook_to_module(module, hook, append=True)
-
-
 class FasterCacheDenoiserHook(ModelHook):
     _is_stateful = True
 
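Note: the `skip_callback` defined in the removed block above (re-added verbatim further down; this commit only reorders definitions) is the heart of the caching schedule: within the configured timestep window, attention is recomputed only on every `block_skip_range`-th iteration and the cached output is reused otherwise. A minimal standalone sketch of that decision rule; the helper name is hypothetical and the values `block_skip_range=2`, window `(-1, 681)` are taken from the docstring example, not library defaults.

```python
def should_skip_attention(
    iteration: int,
    timestep: int,
    block_skip_range: int = 2,
    timestep_skip_range: tuple = (-1, 681),
) -> bool:
    # Outside the window, attention is always computed (never skipped).
    if not (timestep_skip_range[0] < timestep < timestep_skip_range[1]):
        return False
    # Mirrors `skip_callback`: recompute on every `block_skip_range`-th iteration.
    should_compute = iteration > 0 and iteration % block_skip_range == 0
    return not should_compute


for i in range(6):
    print(i, should_skip_attention(i, timestep=500))
# skips (reuses cache) on iterations 0, 1, 3, 5; recomputes on 2 and 4
```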
@@ -632,6 +452,186 @@ def reset_state(self, module: nn.Module) -> nn.Module:
         return module
 
 
+def apply_fastercache(
+    pipeline: DiffusionPipeline,
+    config: Optional[FasterCacheConfig] = None,
+) -> None:
+    r"""
+    Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline.
+
+    Args:
+        pipeline (`DiffusionPipeline`):
+            The diffusion pipeline to apply FasterCache to.
+        config (`Optional[FasterCacheConfig]`, `optional`, defaults to `None`):
+            The configuration to use for FasterCache.
+
+    Example:
+    ```python
+    >>> import torch
+    >>> from diffusers import CogVideoXPipeline, FasterCacheConfig, apply_fastercache
+
+    >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
+    >>> pipe.to("cuda")
+
+    >>> config = FasterCacheConfig(
+    ...     spatial_attention_block_skip_range=2,
+    ...     spatial_attention_timestep_skip_range=(-1, 681),
+    ...     low_frequency_weight_update_timestep_range=(99, 641),
+    ...     high_frequency_weight_update_timestep_range=(-1, 301),
+    ...     spatial_attention_block_identifiers=["transformer_blocks"],
+    ...     attention_weight_callback=lambda _: 0.3,
+    ...     tensor_format="BFCHW",
+    ... )
+    >>> apply_fastercache(pipe, config)
+    ```
+    """
+
+    if config is None:
+        logger.warning("No FasterCacheConfig provided. Using default configuration.")
+        config = FasterCacheConfig()
+
+    if config.attention_weight_callback is None:
+        # If the user has not provided a weight callback, we default to 0.5 for all timesteps.
+        # The paper recommends a weight that gradually increases from 0 to 1 as inference progresses, but this
+        # varies from model to model. Users must provide their own callback if they want a different weighting
+        # function. Defaulting to 0.5 works well in practice for most cases.
+        logger.warning(
+            "No `attention_weight_callback` provided when enabling FasterCache. Defaulting to using a weight of 0.5 for all timesteps."
+        )
+        config.attention_weight_callback = lambda _: 0.5
+
+    if config.low_frequency_weight_callback is None:
+        logger.debug(
+            "Low frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper."
+        )
+
+        def low_frequency_weight_callback(module: nn.Module) -> float:
+            is_within_range = (
+                config.low_frequency_weight_update_timestep_range[0]
+                < pipeline._current_timestep
+                < config.low_frequency_weight_update_timestep_range[1]
+            )
+            return config.alpha_low_frequency if is_within_range else 1.0
+
+        config.low_frequency_weight_callback = low_frequency_weight_callback
+
+    if config.high_frequency_weight_callback is None:
+        logger.debug(
+            "High frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper."
+        )
+
+        def high_frequency_weight_callback(module: nn.Module) -> float:
+            is_within_range = (
+                config.high_frequency_weight_update_timestep_range[0]
+                < pipeline._current_timestep
+                < config.high_frequency_weight_update_timestep_range[1]
+            )
+            return config.alpha_high_frequency if is_within_range else 1.0
+
+        config.high_frequency_weight_callback = high_frequency_weight_callback
+
+    supported_tensor_formats = ["BCFHW", "BFCHW", "BCHW"]  # TODO(aryan): Support BSC for LTX Video
+    if config.tensor_format not in supported_tensor_formats:
+        raise ValueError(f"`tensor_format` must be one of {supported_tensor_formats}, but got {config.tensor_format}.")
+
+    denoiser = pipeline.transformer if hasattr(pipeline, "transformer") else pipeline.unet
+    _apply_fastercache_on_denoiser(pipeline, denoiser, config)
+
+    for name, module in denoiser.named_modules():
+        if not isinstance(module, _ATTENTION_CLASSES):
+            continue
+        if isinstance(module, Attention):
+            _apply_fastercache_on_attention_class(pipeline, name, module, config)
+
+
+def _apply_fastercache_on_denoiser(
+    pipeline: DiffusionPipeline, denoiser: nn.Module, config: FasterCacheConfig
+) -> None:
+    def uncond_skip_callback(module: nn.Module) -> bool:
+        # We skip the unconditional branch only if the following conditions are met:
+        # 1. We have completed at least one iteration of the denoiser
+        # 2. The current timestep is within the range specified by the user. This is the optimal timestep range
+        #    where approximating the unconditional branch from the computation of the conditional branch is possible
+        #    without a significant loss in quality.
+        # 3. The current iteration is not a multiple of the unconditional batch skip range. This is done so that
+        #    we compute the unconditional branch at least once every few iterations to ensure minimal quality loss.
+
+        state: FasterCacheDenoiserState = module._fastercache_state
+        is_within_range = (
+            config.unconditional_batch_timestep_skip_range[0]
+            < pipeline._current_timestep
+            < config.unconditional_batch_timestep_skip_range[1]
+        )
+        return state.iteration > 0 and is_within_range and state.iteration % config.unconditional_batch_skip_range != 0
+
+    denoiser._fastercache_state = FasterCacheDenoiserState(
+        config.low_frequency_weight_callback, config.high_frequency_weight_callback, uncond_skip_callback
+    )
+    hook = FasterCacheDenoiserHook(
+        config._unconditional_conditional_input_kwargs_identifiers,
+        config._guidance_distillation_kwargs_identifiers,
+        config.tensor_format,
+    )
+    add_hook_to_module(denoiser, hook, append=True)
+
+
+def _apply_fastercache_on_attention_class(
+    pipeline: DiffusionPipeline, name: str, module: Attention, config: FasterCacheConfig
+) -> None:
+    is_spatial_self_attention = (
+        any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
+        and config.spatial_attention_block_skip_range is not None
+        and not module.is_cross_attention
+    )
+    is_temporal_self_attention = (
+        any(
+            f"{identifier}." in name or identifier == name
+            for identifier in config.temporal_attention_block_identifiers
+        )
+        and config.temporal_attention_block_skip_range is not None
+        and not module.is_cross_attention
+    )
+
+    block_skip_range, timestep_skip_range, block_type = None, None, None
+    if is_spatial_self_attention:
+        block_skip_range = config.spatial_attention_block_skip_range
+        timestep_skip_range = config.spatial_attention_timestep_skip_range
+        block_type = "spatial"
+    elif is_temporal_self_attention:
+        block_skip_range = config.temporal_attention_block_skip_range
+        timestep_skip_range = config.temporal_attention_timestep_skip_range
+        block_type = "temporal"
+
+    if block_skip_range is None or timestep_skip_range is None:
+        logger.debug(
+            f'Unable to apply FasterCache to the selected layer: "{name}" because it does '
+            f"not match any of the required criteria for spatial or temporal attention layers. Note, "
+            f"however, that this layer may still be valid for applying PAB. Please specify the correct "
+            f"block identifiers in the configuration or use the specialized `apply_fastercache_on_module` "
+            f"function to apply FasterCache to this layer."
+        )
+        return
+
+    def skip_callback(module: nn.Module) -> bool:
+        fastercache_state: FasterCacheBlockState = module._fastercache_state
+        is_within_timestep_range = timestep_skip_range[0] < pipeline._current_timestep < timestep_skip_range[1]
+
+        if not is_within_timestep_range:
+            # We are not yet in the phase of inference where attention can be skipped with minimal quality
+            # loss, as described in the paper. So, the attention computation cannot be skipped.
+            return False
+
+        should_compute_attention = (
+            fastercache_state.iteration > 0 and fastercache_state.iteration % block_skip_range == 0
+        )
+        return not should_compute_attention
+
+    logger.debug(f"Enabling FasterCache ({block_type}) for layer: {name}")
+    module._fastercache_state = FasterCacheBlockState(skip_callback, config.attention_weight_callback)
+    hook = FasterCacheBlockHook()
+    add_hook_to_module(module, hook, append=True)
+
+
 # Reference: https://github.com/Vchitect/FasterCache/blob/fab32c15014636dc854948319c0a9a8d92c7acb4/scripts/latte/fastercache_sample_latte.py#L127C1-L143C39
 @torch.no_grad()
 def _split_low_high_freq(x):
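Similarly, `uncond_skip_callback` in `_apply_fastercache_on_denoiser` above amortizes the classifier-free-guidance cost: inside the configured timestep window, the unconditional branch is recomputed only on every `unconditional_batch_skip_range`-th iteration and approximated from the conditional branch otherwise. A small sketch of that predicate, with assumed illustrative values (`skip_range=5`, window `(-1, 641)`; neither is taken from this diff).

```python
def should_skip_uncond(
    iteration: int,
    timestep: int,
    skip_range: int = 5,
    timestep_skip_range: tuple = (-1, 641),
) -> bool:
    within = timestep_skip_range[0] < timestep < timestep_skip_range[1]
    # Mirrors `uncond_skip_callback`: skip unless this is a refresh iteration.
    return iteration > 0 and within and iteration % skip_range != 0


assert should_skip_uncond(3, timestep=300)      # approximated from the conditional branch
assert not should_skip_uncond(5, timestep=300)  # recomputed on every 5th iteration
```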