@@ -225,186 +225,6 @@ def reset(self):
         self.is_guidance_distilled = None
 
 
-def apply_fastercache(
-    pipeline: DiffusionPipeline,
-    config: Optional[FasterCacheConfig] = None,
-) -> None:
-    r"""
-    Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline.
-
-    Args:
-        pipeline (`DiffusionPipeline`):
-            The diffusion pipeline to apply FasterCache to.
-        config (`Optional[FasterCacheConfig]`, `optional`, defaults to `None`):
-            The configuration to use for FasterCache.
-
-    Example:
-    ```python
-    >>> import torch
-    >>> from diffusers import CogVideoXPipeline, FasterCacheConfig, apply_fastercache
-
-    >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
-    >>> pipe.to("cuda")
-
-    >>> config = FasterCacheConfig(
-    ...     spatial_attention_block_skip_range=2,
-    ...     spatial_attention_timestep_skip_range=(-1, 681),
-    ...     low_frequency_weight_update_timestep_range=(99, 641),
-    ...     high_frequency_weight_update_timestep_range=(-1, 301),
-    ...     spatial_attention_block_identifiers=["transformer_blocks"],
-    ...     attention_weight_callback=lambda _: 0.3,
-    ...     tensor_format="BFCHW",
-    ... )
-    >>> apply_fastercache(pipe, config)
-    ```
-    """
-
-    if config is None:
-        logger.warning("No FasterCacheConfig provided. Using default configuration.")
-        config = FasterCacheConfig()
-
-    if config.attention_weight_callback is None:
-        # If the user has not provided a weight callback, we default to 0.5 for all timesteps.
-        # The paper recommends a weight that gradually increases from 0 to 1 as inference progresses, but the
-        # best schedule varies from model to model. Users who want a different weight function must provide
-        # their own callback. Defaulting to 0.5 works well in practice for most cases.
-        logger.warning(
-            "No `attention_weight_callback` provided when enabling FasterCache. Defaulting to using a weight of 0.5 for all timesteps."
-        )
-        config.attention_weight_callback = lambda _: 0.5
-
-    if config.low_frequency_weight_callback is None:
-        logger.debug(
-            "Low frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper."
-        )
-
-        def low_frequency_weight_callback(module: nn.Module) -> float:
-            is_within_range = (
-                config.low_frequency_weight_update_timestep_range[0]
-                < pipeline._current_timestep
-                < config.low_frequency_weight_update_timestep_range[1]
-            )
-            return config.alpha_low_frequency if is_within_range else 1.0
-
-        config.low_frequency_weight_callback = low_frequency_weight_callback
-
-    if config.high_frequency_weight_callback is None:
-        logger.debug(
-            "High frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper."
-        )
-
-        def high_frequency_weight_callback(module: nn.Module) -> float:
-            is_within_range = (
-                config.high_frequency_weight_update_timestep_range[0]
-                < pipeline._current_timestep
-                < config.high_frequency_weight_update_timestep_range[1]
-            )
-            return config.alpha_high_frequency if is_within_range else 1.0
-
-        config.high_frequency_weight_callback = high_frequency_weight_callback
-
-    supported_tensor_formats = ["BCFHW", "BFCHW", "BCHW"]  # TODO(aryan): Support BSC for LTX Video
-    if config.tensor_format not in supported_tensor_formats:
-        raise ValueError(f"`tensor_format` must be one of {supported_tensor_formats}, but got {config.tensor_format}.")
-
-    denoiser = pipeline.transformer if hasattr(pipeline, "transformer") else pipeline.unet
-    _apply_fastercache_on_denoiser(pipeline, denoiser, config)
-
-    for name, module in denoiser.named_modules():
-        if not isinstance(module, _ATTENTION_CLASSES):
-            continue
-        if isinstance(module, Attention):
-            _apply_fastercache_on_attention_class(pipeline, name, module, config)
-
-
-def _apply_fastercache_on_denoiser(
-    pipeline: DiffusionPipeline, denoiser: nn.Module, config: FasterCacheConfig
-) -> None:
-    def uncond_skip_callback(module: nn.Module) -> bool:
-        # We skip the unconditional branch only if the following conditions are met:
-        # 1. We have completed at least one iteration of the denoiser
-        # 2. The current timestep is within the range specified by the user. This is the optimal timestep range
-        #    where approximating the unconditional branch from the computation of the conditional branch is possible
-        #    without a significant loss in quality.
-        # 3. The current iteration is not a multiple of the unconditional batch skip range. This is done so that
-        #    we compute the unconditional branch at least once every few iterations to ensure minimal quality loss.
-
-        state: FasterCacheDenoiserState = module._fastercache_state
-        is_within_range = (
-            config.unconditional_batch_timestep_skip_range[0]
-            < pipeline._current_timestep
-            < config.unconditional_batch_timestep_skip_range[1]
-        )
-        return state.iteration > 0 and is_within_range and state.iteration % config.unconditional_batch_skip_range != 0
-
-    denoiser._fastercache_state = FasterCacheDenoiserState(
-        config.low_frequency_weight_callback, config.high_frequency_weight_callback, uncond_skip_callback
-    )
-    hook = FasterCacheDenoiserHook(
-        config._unconditional_conditional_input_kwargs_identifiers,
-        config._guidance_distillation_kwargs_identifiers,
-        config.tensor_format,
-    )
-    add_hook_to_module(denoiser, hook, append=True)
-
-
-def _apply_fastercache_on_attention_class(
-    pipeline: DiffusionPipeline, name: str, module: Attention, config: FasterCacheConfig
-) -> None:
-    is_spatial_self_attention = (
-        any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
-        and config.spatial_attention_block_skip_range is not None
-        and not module.is_cross_attention
-    )
-    is_temporal_self_attention = (
-        any(
-            f"{identifier}." in name or identifier == name
-            for identifier in config.temporal_attention_block_identifiers
-        )
-        and config.temporal_attention_block_skip_range is not None
-        and not module.is_cross_attention
-    )
-
-    block_skip_range, timestep_skip_range, block_type = None, None, None
-    if is_spatial_self_attention:
-        block_skip_range = config.spatial_attention_block_skip_range
-        timestep_skip_range = config.spatial_attention_timestep_skip_range
-        block_type = "spatial"
-    elif is_temporal_self_attention:
-        block_skip_range = config.temporal_attention_block_skip_range
-        timestep_skip_range = config.temporal_attention_timestep_skip_range
-        block_type = "temporal"
-
-    if block_skip_range is None or timestep_skip_range is None:
-        logger.debug(
-            f'Unable to apply FasterCache to the selected layer: "{name}" because it does '
-            f"not match any of the required criteria for spatial or temporal attention layers. Note, "
-            f"however, that this layer may still be valid for applying PAB. Please specify the correct "
-            f"block identifiers in the configuration or use the specialized `apply_fastercache_on_module` "
-            f"function to apply FasterCache to this layer."
-        )
-        return
-
-    def skip_callback(module: nn.Module) -> bool:
-        fastercache_state: FasterCacheBlockState = module._fastercache_state
-        is_within_timestep_range = timestep_skip_range[0] < pipeline._current_timestep < timestep_skip_range[1]
-
-        if not is_within_timestep_range:
-            # We are not yet in the phase of inference where attention can be skipped with minimal quality
-            # loss, as described in the paper, so the attention computation cannot be skipped.
-            return False
-
-        should_compute_attention = (
-            fastercache_state.iteration > 0 and fastercache_state.iteration % block_skip_range == 0
-        )
-        return not should_compute_attention
-
-    logger.debug(f"Enabling FasterCache ({block_type}) for layer: {name}")
-    module._fastercache_state = FasterCacheBlockState(skip_callback, config.attention_weight_callback)
-    hook = FasterCacheBlockHook()
-    add_hook_to_module(module, hook, append=True)
-
-
 class FasterCacheDenoiserHook(ModelHook):
     _is_stateful = True
 
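The in-code comment in this hunk notes that the paper recommends an attention weight that ramps from 0 to 1 over the course of inference rather than the constant 0.5 default. Below is a minimal sketch of such a callback, reusing `pipeline._current_timestep` the same way the low/high frequency callbacks in this diff do. The 1000-step timestep scale and the 0.5 fallback are assumptions (the right values depend on the scheduler), and `FasterCacheConfig` is assumed to be importable as in the docstring example above.

```python
def make_ramping_weight_callback(pipeline, max_timestep: float = 1000.0):
    """Build an attention weight callback that grows from 0 (early timesteps) to 1 (late timesteps)."""

    def attention_weight_callback(module) -> float:
        timestep = getattr(pipeline, "_current_timestep", None)
        if timestep is None:
            # No denoising step is active; fall back to the constant default used by apply_fastercache.
            return 0.5
        # Diffusion timesteps typically count down from ~max_timestep to 0, so invert and clamp to [0, 1].
        return min(max(1.0 - timestep / max_timestep, 0.0), 1.0)

    return attention_weight_callback


# Hypothetical usage, mirroring the docstring example:
# config = FasterCacheConfig(attention_weight_callback=make_ramping_weight_callback(pipe))
```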
@@ -632,6 +452,186 @@ def reset_state(self, module: nn.Module) -> nn.Module:
         return module
 
 
+def apply_fastercache(
+    pipeline: DiffusionPipeline,
+    config: Optional[FasterCacheConfig] = None,
+) -> None:
+    r"""
+    Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline.
+
+    Args:
+        pipeline (`DiffusionPipeline`):
+            The diffusion pipeline to apply FasterCache to.
+        config (`Optional[FasterCacheConfig]`, `optional`, defaults to `None`):
+            The configuration to use for FasterCache.
+
+    Example:
+    ```python
+    >>> import torch
+    >>> from diffusers import CogVideoXPipeline, FasterCacheConfig, apply_fastercache
+
+    >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
+    >>> pipe.to("cuda")
+
+    >>> config = FasterCacheConfig(
+    ...     spatial_attention_block_skip_range=2,
+    ...     spatial_attention_timestep_skip_range=(-1, 681),
+    ...     low_frequency_weight_update_timestep_range=(99, 641),
+    ...     high_frequency_weight_update_timestep_range=(-1, 301),
+    ...     spatial_attention_block_identifiers=["transformer_blocks"],
+    ...     attention_weight_callback=lambda _: 0.3,
+    ...     tensor_format="BFCHW",
+    ... )
+    >>> apply_fastercache(pipe, config)
+    ```
+    """
+
+    if config is None:
+        logger.warning("No FasterCacheConfig provided. Using default configuration.")
+        config = FasterCacheConfig()
+
+    if config.attention_weight_callback is None:
+        # If the user has not provided a weight callback, we default to 0.5 for all timesteps.
+        # The paper recommends a weight that gradually increases from 0 to 1 as inference progresses, but the
+        # best schedule varies from model to model. Users who want a different weight function must provide
+        # their own callback. Defaulting to 0.5 works well in practice for most cases.
+        logger.warning(
+            "No `attention_weight_callback` provided when enabling FasterCache. Defaulting to using a weight of 0.5 for all timesteps."
+        )
+        config.attention_weight_callback = lambda _: 0.5
+
+    if config.low_frequency_weight_callback is None:
+        logger.debug(
+            "Low frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper."
+        )
+
+        def low_frequency_weight_callback(module: nn.Module) -> float:
+            is_within_range = (
+                config.low_frequency_weight_update_timestep_range[0]
+                < pipeline._current_timestep
+                < config.low_frequency_weight_update_timestep_range[1]
+            )
+            return config.alpha_low_frequency if is_within_range else 1.0
+
+        config.low_frequency_weight_callback = low_frequency_weight_callback
+
+    if config.high_frequency_weight_callback is None:
+        logger.debug(
+            "High frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper."
+        )
+
+        def high_frequency_weight_callback(module: nn.Module) -> float:
+            is_within_range = (
+                config.high_frequency_weight_update_timestep_range[0]
+                < pipeline._current_timestep
+                < config.high_frequency_weight_update_timestep_range[1]
+            )
+            return config.alpha_high_frequency if is_within_range else 1.0
+
+        config.high_frequency_weight_callback = high_frequency_weight_callback
+
+    supported_tensor_formats = ["BCFHW", "BFCHW", "BCHW"]  # TODO(aryan): Support BSC for LTX Video
+    if config.tensor_format not in supported_tensor_formats:
+        raise ValueError(f"`tensor_format` must be one of {supported_tensor_formats}, but got {config.tensor_format}.")
+
+    denoiser = pipeline.transformer if hasattr(pipeline, "transformer") else pipeline.unet
+    _apply_fastercache_on_denoiser(pipeline, denoiser, config)
+
+    for name, module in denoiser.named_modules():
+        if not isinstance(module, _ATTENTION_CLASSES):
+            continue
+        if isinstance(module, Attention):
+            _apply_fastercache_on_attention_class(pipeline, name, module, config)
+
+
+def _apply_fastercache_on_denoiser(
+    pipeline: DiffusionPipeline, denoiser: nn.Module, config: FasterCacheConfig
+) -> None:
+    def uncond_skip_callback(module: nn.Module) -> bool:
+        # We skip the unconditional branch only if the following conditions are met:
+        # 1. We have completed at least one iteration of the denoiser
+        # 2. The current timestep is within the range specified by the user. This is the optimal timestep range
+        #    where approximating the unconditional branch from the computation of the conditional branch is possible
+        #    without a significant loss in quality.
+        # 3. The current iteration is not a multiple of the unconditional batch skip range. This is done so that
+        #    we compute the unconditional branch at least once every few iterations to ensure minimal quality loss.
+
+        state: FasterCacheDenoiserState = module._fastercache_state
+        is_within_range = (
+            config.unconditional_batch_timestep_skip_range[0]
+            < pipeline._current_timestep
+            < config.unconditional_batch_timestep_skip_range[1]
+        )
+        return state.iteration > 0 and is_within_range and state.iteration % config.unconditional_batch_skip_range != 0
+
+    denoiser._fastercache_state = FasterCacheDenoiserState(
+        config.low_frequency_weight_callback, config.high_frequency_weight_callback, uncond_skip_callback
+    )
+    hook = FasterCacheDenoiserHook(
+        config._unconditional_conditional_input_kwargs_identifiers,
+        config._guidance_distillation_kwargs_identifiers,
+        config.tensor_format,
+    )
+    add_hook_to_module(denoiser, hook, append=True)
+
+
+def _apply_fastercache_on_attention_class(
+    pipeline: DiffusionPipeline, name: str, module: Attention, config: FasterCacheConfig
+) -> None:
+    is_spatial_self_attention = (
+        any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
+        and config.spatial_attention_block_skip_range is not None
+        and not module.is_cross_attention
+    )
+    is_temporal_self_attention = (
+        any(
+            f"{identifier}." in name or identifier == name
+            for identifier in config.temporal_attention_block_identifiers
+        )
+        and config.temporal_attention_block_skip_range is not None
+        and not module.is_cross_attention
+    )
+
+    block_skip_range, timestep_skip_range, block_type = None, None, None
+    if is_spatial_self_attention:
+        block_skip_range = config.spatial_attention_block_skip_range
+        timestep_skip_range = config.spatial_attention_timestep_skip_range
+        block_type = "spatial"
+    elif is_temporal_self_attention:
+        block_skip_range = config.temporal_attention_block_skip_range
+        timestep_skip_range = config.temporal_attention_timestep_skip_range
+        block_type = "temporal"
+
+    if block_skip_range is None or timestep_skip_range is None:
+        logger.debug(
+            f'Unable to apply FasterCache to the selected layer: "{name}" because it does '
+            f"not match any of the required criteria for spatial or temporal attention layers. Note, "
+            f"however, that this layer may still be valid for applying PAB. Please specify the correct "
+            f"block identifiers in the configuration or use the specialized `apply_fastercache_on_module` "
+            f"function to apply FasterCache to this layer."
+        )
+        return
+
+    def skip_callback(module: nn.Module) -> bool:
+        fastercache_state: FasterCacheBlockState = module._fastercache_state
+        is_within_timestep_range = timestep_skip_range[0] < pipeline._current_timestep < timestep_skip_range[1]
+
+        if not is_within_timestep_range:
+            # We are not yet in the phase of inference where attention can be skipped with minimal quality
+            # loss, as described in the paper, so the attention computation cannot be skipped.
+            return False
+
+        should_compute_attention = (
+            fastercache_state.iteration > 0 and fastercache_state.iteration % block_skip_range == 0
+        )
+        return not should_compute_attention
+
+    logger.debug(f"Enabling FasterCache ({block_type}) for layer: {name}")
+    module._fastercache_state = FasterCacheBlockState(skip_callback, config.attention_weight_callback)
+    hook = FasterCacheBlockHook()
+    add_hook_to_module(module, hook, append=True)
+
+
 # Reference: https://github.com/Vchitect/FasterCache/blob/fab32c15014636dc854948319c0a9a8d92c7acb4/scripts/latte/fastercache_sample_latte.py#L127C1-L143C39
 @torch.no_grad()
 def _split_low_high_freq(x):
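For the attention-level `skip_callback` shown in this diff, attention is recomputed only at iterations that are positive multiples of `block_skip_range`; every other iteration inside the configured timestep window reuses the cached output. A small standalone illustration of that cadence follows; the `block_skip_range` value of 2 and the 0..9 iteration range are made-up example inputs, not library defaults.

```python
# Standalone illustration of the attention skip cadence implemented by `skip_callback` above.
block_skip_range = 2


def should_skip_attention(iteration: int, within_timestep_window: bool) -> bool:
    if not within_timestep_window:
        # Outside the configured timestep window, attention is always computed.
        return False
    should_compute_attention = iteration > 0 and iteration % block_skip_range == 0
    return not should_compute_attention


for iteration in range(10):
    action = "reuse cached output" if should_skip_attention(iteration, True) else "compute attention"
    print(f"iteration {iteration}: {action}")
# Within the window, only iterations 2, 4, 6 and 8 recompute attention; the rest reuse the cache.
```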