Commit 5359222

Commit message: update
Parent: 02a2e0d

5 files changed: 89 additions, 61 deletions


src/diffusers/models/embeddings.py

Lines changed: 1 addition & 1 deletion

@@ -334,7 +334,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
             " `from_numpy` is no longer required."
             " Pass `output_type='pt' to use the new version now."
         )
-        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+        # deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
         return get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos)
     if embed_dim % 2 != 0:
         raise ValueError("embed_dim must be divisible by 2")

src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py

Lines changed: 2 additions & 0 deletions

@@ -24,6 +24,7 @@
 from ...loaders import CogVideoXLoraLoaderMixin
 from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
 from ...models.embeddings import get_3d_rotary_pos_embed
+from ...models.hooks import reset_stateful_hooks
 from ...pipelines.pipeline_utils import DiffusionPipeline
 from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
 from ...utils import logging, replace_example_docstring

@@ -769,6 +770,7 @@ def __call__(

         # Offload all models
         self.maybe_free_model_hooks()
+        reset_stateful_hooks(self.transformer, recurse=True)

         if not return_dict:
             return (video,)
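
For context, a minimal usage sketch of how the two additions fit together: apply_faster_cache (from faster_cache_utils.py below) installs stateful hooks on the transformer, and the new reset_stateful_hooks call clears their per-call caches once denoising finishes. The pipeline class, checkpoint id, and the keyword construction of FasterCacheConfig are illustrative assumptions, not part of this commit.

import torch
from diffusers import CogVideoXPipeline  # illustrative pipeline choice
from diffusers.pipelines.faster_cache_utils import FasterCacheConfig, apply_faster_cache  # module path per this commit

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")

# Install the FasterCache hooks on the denoiser; with this commit, the spatial attention
# skip range defaults to 2 when no value is provided.
apply_faster_cache(pipe, FasterCacheConfig())

# __call__ now ends with reset_stateful_hooks(self.transformer, recurse=True), so the caches
# built up during this generation do not leak into the next one.
video = pipe("a panda playing a guitar", num_inference_steps=50).frames[0]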

src/diffusers/pipelines/faster_cache_utils.py

Lines changed: 70 additions & 57 deletions

@@ -49,13 +49,12 @@
 class FasterCacheConfig:
     r"""
     Configuration for [FasterCache](https://huggingface.co/papers/2410.19355).
-    """

-    num_train_timesteps: int = 1000
+    Attributes:"""

     # In the paper and codebase, they hardcode these values to 2. However, it can be made configurable
     # after some testing. We default to 2 if these parameters are not provided.
-    spatial_attention_block_skip_range: Optional[int] = None
+    spatial_attention_block_skip_range: int = 2
     temporal_attention_block_skip_range: Optional[int] = None

     # TODO(aryan): write heuristics for what the best way to obtain these values are
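
What the changed default means for callers, as a small sketch; it assumes FasterCacheConfig is a keyword-constructible dataclass, which this hunk does not show explicitly.

from diffusers.pipelines.faster_cache_utils import FasterCacheConfig  # module path per this commit

# Omitting spatial_attention_block_skip_range no longer relies on a warning-and-default
# fallback inside apply_faster_cache; the config itself now defaults to 2.
config = FasterCacheConfig()
assert config.spatial_attention_block_skip_range == 2
assert config.temporal_attention_block_skip_range is None  # temporal skipping stays opt-in

# Overriding is still possible, e.g. recompute spatial attention only on every 4th eligible step.
config = FasterCacheConfig(spatial_attention_block_skip_range=4)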
@@ -145,6 +144,9 @@ def apply_faster_cache(
     r"""
     Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline.

+    Note: FasterCache should only be applied when using classifier-free guidance; otherwise it will not work as
+    expected, even if inference runs successfully.
+
     Args:
         pipeline (`DiffusionPipeline`):
             The diffusion pipeline to apply FasterCache to.
@@ -163,15 +165,6 @@ def apply_faster_cache(
     if config is None:
         config = FasterCacheConfig()

-    if config.spatial_attention_block_skip_range is None and config.temporal_attention_block_skip_range is None:
-        logger.warning(
-            "FasterCache requires one of `spatial_attention_block_skip_range` and/or `temporal_attention_block_skip_range` "
-            "to be set to an integer, not `None`. Defaulting to using `spatial_attention_block_skip_range=2` and "
-            "`temporal_attention_block_skip_range=2`. To avoid this warning, please set one of the above parameters."
-        )
-        config.spatial_attention_block_skip_range = 2
-        config.temporal_attention_block_skip_range = 2
-
     if config.attention_weight_callback is None:
         # If the user has not provided a weight callback, we default to 0.5 for all timesteps.
         # In the paper, they recommend using a gradually increasing weight from 0 to 1 as the inference progresses, but
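
The comment above references the paper's recommendation of a weight that ramps up as inference progresses. A hedged sketch of such a callback, using the Callable[[nn.Module], float] shape implied by the state.weight_callback(module) calls later in this file; the closure over the pipeline's _current_timestep and the 1000-step range are illustrative assumptions.

import torch.nn as nn

def make_ramp_weight_callback(pipeline, num_train_timesteps: int = 1000):
    # Weight grows from ~0 early in denoising (t near num_train_timesteps) toward 1 at the end
    # (t near 0), approximating the increasing schedule the paper recommends instead of a flat 0.5.
    def attention_weight_callback(module: nn.Module) -> float:
        t = float(pipeline._current_timestep)  # tracked by the pipelines touched in this commit
        return 1.0 - t / num_train_timesteps
    return attention_weight_callback

# Quick standalone check with a stand-in object; in practice `pipeline` is the DiffusionPipeline
# and the callback would be passed as FasterCacheConfig(attention_weight_callback=...).
class _FakePipeline:
    _current_timestep = 750.0

print(make_ramp_weight_callback(_FakePipeline())(nn.Identity()))  # 0.25 this early in denoising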
@@ -231,12 +224,6 @@ def _apply_fastercache_on_denoiser(
     pipeline: DiffusionPipeline, denoiser: nn.Module, config: FasterCacheConfig
 ) -> None:
     def uncond_skip_callback(module: nn.Module) -> bool:
-        # If we are not using classifier-free guidance, we cannot skip the denoiser computation. We only compute the
-        # conditional branch in this case.
-        is_using_classifier_free_guidance = pipeline.do_classifier_free_guidance
-        if not is_using_classifier_free_guidance:
-            return False
-
         # We skip the unconditional branch only if the following conditions are met:
         # 1. We have completed at least one iteration of the denoiser
         # 2. The current timestep is within the range specified by the user. This is the optimal timestep range
@@ -298,20 +285,13 @@ def _apply_fastercache_on_attention_class(
         return

     def skip_callback(module: nn.Module) -> bool:
-        is_using_classifier_free_guidance = pipeline.do_classifier_free_guidance
-        if not is_using_classifier_free_guidance:
-            return False
-
         fastercache_state: FasterCacheState = module._fastercache_state
         is_within_timestep_range = timestep_skip_range[0] < pipeline._current_timestep < timestep_skip_range[1]

         if not is_within_timestep_range:
             # We are still not in the phase of inference where skipping attention is possible with minimal quality
             # loss, as described in the paper. So, the attention computation cannot be skipped.
             return False
-        if fastercache_state.cache is None or fastercache_state.iteration < 2:
-            # We need at least 2 iterations to start skipping attention computation
-            return False

         should_compute_attention = (
             fastercache_state.iteration > 0 and fastercache_state.iteration % block_skip_range == 0
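
For intuition about the cadence this predicate produces, a small standalone sketch; it assumes skip_callback ultimately skips when should_compute_attention is False, which the truncated hunk does not show, and it ignores the timestep-range gate above.

block_skip_range = 2  # the new FasterCacheConfig default for spatial attention blocks

for iteration in range(1, 9):
    should_compute_attention = iteration > 0 and iteration % block_skip_range == 0
    # With block_skip_range=2, attention is recomputed on even iterations (2, 4, 6, ...) and
    # approximated from the cache on the others, roughly halving attention cost inside the
    # timestep window where skipping is allowed.
    print(iteration, "compute" if should_compute_attention else "approximate from cache")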
@@ -358,8 +338,6 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any:
             # TODO(aryan): remove later
             logger.debug("Skipping unconditional branch computation")

-        if should_skip_uncond:
-            breakpoint()
         output = module._old_forward(*args, **kwargs)
         # TODO(aryan): handle Transformer2DModelOutput
         hidden_states = output[0] if isinstance(output, tuple) else output
@@ -422,6 +400,22 @@ def reset_state(self, module: nn.Module) -> None:
 class FasterCacheBlockHook(ModelHook):
     _is_stateful = True

+    def _compute_approximated_attention_output(
+        self, t_2_output: torch.Tensor, t_output: torch.Tensor, weight: float, batch_size: int
+    ) -> torch.Tensor:
+        # TODO(aryan): these conditions may not be needed after latest refactor. they exist for safety. do test if they can be removed
+        if t_2_output.size(0) != batch_size:
+            # The cache t_2_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just
+            # take the conditional branch outputs.
+            assert t_2_output.size(0) == 2 * batch_size
+            t_2_output = t_2_output[batch_size:]
+        if t_output.size(0) != batch_size:
+            # The cache t_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just
+            # take the conditional branch outputs.
+            assert t_output.size(0) == 2 * batch_size
+            t_output = t_output[batch_size:]
+        return t_output + (t_output - t_2_output) * weight
+
     def new_forward(self, module: nn.Module, *args, **kwargs) -> Any:
         args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
         state: FasterCacheState = module._fastercache_state
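
The helper factored out above is the core FasterCache approximation: linearly extrapolate the current attention output from the two most recently cached conditional-branch outputs. A standalone numerical sketch with illustrative tensor shapes:

import torch

batch_size, seq_len, dim = 2, 16, 64
weight = 0.5  # the constant returned by the default attention_weight_callback in this file

# Cached outputs from the two most recent computed iterations, stored batchwise-concatenated
# as [unconditional, conditional]; the helper slices out the conditional half before extrapolating.
t_2_output = torch.randn(2 * batch_size, seq_len, dim)  # older cached output
t_output = torch.randn(2 * batch_size, seq_len, dim)    # newer cached output

t_2_cond = t_2_output[batch_size:]
t_cond = t_output[batch_size:]

# output ~= newer + (newer - older) * weight, i.e. continue the trajectory of attention outputs.
approx = t_cond + (t_cond - t_2_cond) * weight
assert approx.shape == (batch_size, seq_len, dim)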
@@ -435,40 +429,59 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any:
             state.batch_size = batch_size

         # If we have to skip due to the skip conditions, then let's skip as expected.
-        # But, we can't skip if the denoiser wants to infer both unconditional and conditional branches. So,
-        # if state.batch_size (which is the true unconditional-conditional batch size) is same as the current
-        # batch size, we don't perform the layer skip. Otherwise, we conditionally skip the layer based on
-        # what state.skip_callback returns.
-        if state.skip_callback(module) and state.batch_size != batch_size:
+        # But, we can't skip if the denoiser wants to infer both unconditional and conditional branches. This
+        # is because the expected output shapes of the attention layer will not match if we only return values from
+        # the cache (which only caches conditional branch outputs). So, if state.batch_size (which is the true
+        # unconditional-conditional batch size) is same as the current batch size, we don't perform the layer
+        # skip. Otherwise, we conditionally skip the layer based on what state.skip_callback returns.
+        should_skip_attention = state.skip_callback(module) and state.batch_size != batch_size
+
+        if should_skip_attention:
             # TODO(aryan): remove later
-            logger.debug("Skipping layer computation")
-            t_2_output, t_output = state.cache
-
-            # TODO(aryan): these conditions may not be needed after latest refactor. they exist for safety. do test if they can be removed
-            if t_2_output.size(0) != batch_size:
-                # The cache t_2_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just
-                # take the conditional branch outputs.
-                assert t_2_output.size(0) == 2 * batch_size
-                t_2_output = t_2_output[batch_size:]
-            if t_output.size(0) != batch_size:
-                # The cache t_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just
-                # take the conditional branch outputs.
-                assert t_output.size(0) == 2 * batch_size
-                t_output = t_output[batch_size:]
-
-            output = t_output + (t_output - t_2_output) * state.weight_callback(module)
+            logger.debug("Skipping attention")
+
+            if torch.is_tensor(state.cache):
+                t_2_output, t_output = state.cache
+                weight = state.weight_callback(module)
+                output = self._compute_approximated_attention_output(t_2_output, t_output, weight, batch_size)
+            else:
+                # The cache contains multiple tensors from past N iterations (N=2 for FasterCache). We need to handle all of them.
+                # Diffusers blocks can return multiple tensors - let's call them [A, B, C, ...] for simplicity.
+                # In our cache, we would have [[A_1, B_1, C_1, ...], [A_2, B_2, C_2, ...], ...] where each list is the output from
+                # a forward pass of the block. We need to compute the approximated output for each of these tensors.
+                # The zip(*state.cache) operation will give us [(A_1, A_2, ...), (B_1, B_2, ...), (C_1, C_2, ...), ...] which
+                # allows us to compute the approximated attention output for each tensor in the cache.
+                output = ()
+                for t_2_output, t_output in zip(*state.cache):
+                    result = self._compute_approximated_attention_output(
+                        t_2_output, t_output, state.weight_callback(module), batch_size
+                    )
+                    output += (result,)
         else:
+            logger.debug("Computing attention")
             output = module._old_forward(*args, **kwargs)

-        # The output here can be both unconditional-conditional branch outputs or just conditional branch outputs.
-        # This is determined at the higher-level denoiser module. We only want to cache the conditional branch outputs.
-        cache_output = output
-        if output.size(0) == state.batch_size:
-            cache_output = cache_output.chunk(2, dim=0)[1]
-
-        # Just to be safe that the output is of the correct size for both unconditional-conditional branch inference
-        # and only-conditional branch inference.
-        assert 2 * cache_output.size(0) == state.batch_size
+        # Note that the following condition for getting hidden_states should suffice since Diffusers blocks either return
+        # a single hidden_states tensor, or a tuple of (hidden_states, encoder_hidden_states) tensors. We need to handle
+        # both cases.
+        if torch.is_tensor(output):
+            cache_output = output
+            if cache_output.size(0) == state.batch_size:
+                # The output here can be both unconditional-conditional branch outputs or just conditional branch outputs.
+                # This is determined at the higher-level denoiser module. We only want to cache the conditional branch outputs.
+                cache_output = cache_output.chunk(2, dim=0)[1]

+            # Just to be safe that the output is of the correct size for both unconditional-conditional branch inference
+            # and only-conditional branch inference.
+            assert 2 * cache_output.size(0) == state.batch_size
+        else:
+            # Cache all return values and perform the same operation as above
+            cache_output = ()
+            for out in output:
+                if out.size(0) == state.batch_size:
+                    out = out.chunk(2, dim=0)[1]
+                assert 2 * out.size(0) == state.batch_size
+                cache_output += (out,)

         if state.cache is None:
             state.cache = [cache_output, cache_output]
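
A small standalone illustration of the tuple-handling path added above, using plain tensors in place of real block outputs; it assumes the cached tuples already hold conditional-only tensors, so the branch slicing inside the helper is skipped here, and the two-element cache layout mirrors state.cache = [cache_output, cache_output].

import torch

batch_size, seq_len, dim = 2, 8, 4
weight = 0.5

def approximate(t_2_output, t_output):
    # Same linear extrapolation as _compute_approximated_attention_output, minus the
    # conditional-branch slicing, which these already-sliced tensors do not need.
    return t_output + (t_output - t_2_output) * weight

# Each cached entry is the tuple a block returned, e.g. (hidden_states, encoder_hidden_states).
cache_older = (torch.randn(batch_size, seq_len, dim), torch.randn(batch_size, 4, dim))
cache_newer = (torch.randn(batch_size, seq_len, dim), torch.randn(batch_size, 4, dim))
cache = [cache_older, cache_newer]

# zip(*cache) pairs up the older and newer versions of each returned tensor:
# [(hidden_states_old, hidden_states_new), (encoder_hidden_states_old, encoder_hidden_states_new)]
output = tuple(approximate(older, newer) for older, newer in zip(*cache))
assert len(output) == 2 and output[0].shape == (batch_size, seq_len, dim)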

src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py

Lines changed: 6 additions & 0 deletions

@@ -22,6 +22,7 @@
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...loaders import HunyuanVideoLoraLoaderMixin
 from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel
+from ...models.hooks import reset_stateful_hooks
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor

@@ -573,6 +574,7 @@ def __call__(

         self._guidance_scale = guidance_scale
         self._attention_kwargs = attention_kwargs
+        self._current_timestep = None
         self._interrupt = False

         device = self._execution_device

@@ -640,6 +642,7 @@ def __call__(
                 if self.interrupt:
                     continue

+                self._current_timestep = t
                 latent_model_input = latents.to(transformer_dtype)
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)

@@ -671,6 +674,8 @@ def __call__(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()

+        self._current_timestep = None
+
         if not output_type == "latent":
             latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
             video = self.vae.decode(latents, return_dict=False)[0]

@@ -680,6 +685,7 @@ def __call__(

         # Offload all models
         self.maybe_free_model_hooks()
+        reset_stateful_hooks(self.transformer, recurse=True)

         if not return_dict:
             return (video,)
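
Why the pipelines now track _current_timestep: the FasterCache skip_callback above gates skipping on timestep_skip_range[0] < pipeline._current_timestep < timestep_skip_range[1]. A minimal sketch of that gate in isolation; the range values are illustrative, not confirmed library defaults.

# Illustrative skip window (the real defaults live in FasterCacheConfig): only approximate
# attention once denoising has reached timesteps inside this range.
timestep_skip_range = (-1, 681)

def is_within_skip_range(current_timestep: float) -> bool:
    return timestep_skip_range[0] < current_timestep < timestep_skip_range[1]

# The pipeline sets self._current_timestep = t on every denoising step and resets it to None
# after the loop, so a hook reading it never sees a stale value from a previous call.
for t in (999.0, 750.0, 500.0, 100.0):
    print(t, "skip allowed" if is_within_skip_range(t) else "always compute")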

src/diffusers/pipelines/mochi/pipeline_mochi.py

Lines changed: 10 additions & 3 deletions

@@ -21,8 +21,8 @@

 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...loaders import Mochi1LoraLoaderMixin
-from ...models.autoencoders import AutoencoderKL
-from ...models.transformers import MochiTransformer3DModel
+from ...models import AutoencoderKLHunyuanVideo, MochiTransformer3DModel
+from ...models.hooks import reset_stateful_hooks
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import (
     is_torch_xla_available,

@@ -184,7 +184,7 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
     def __init__(
         self,
         scheduler: FlowMatchEulerDiscreteScheduler,
-        vae: AutoencoderKL,
+        vae: AutoencoderKLHunyuanVideo,
         text_encoder: T5EncoderModel,
         tokenizer: T5TokenizerFast,
         transformer: MochiTransformer3DModel,

@@ -604,6 +604,7 @@ def __call__(

         self._guidance_scale = guidance_scale
         self._attention_kwargs = attention_kwargs
+        self._current_timestep = None
         self._interrupt = False

         # 2. Define call parameters

@@ -673,6 +674,9 @@ def __call__(
                 if self.interrupt:
                     continue

+                # Note: Mochi uses reversed timesteps. To ensure compatibility with methods like FasterCache, we need
+                # to make sure we're using the correct non-reversed timestep values.
+                self._current_timestep = 1000 - t
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)

@@ -718,6 +722,8 @@ def __call__(
                 if XLA_AVAILABLE:
                     xm.mark_step()

+        self._current_timestep = None
+
         if output_type == "latent":
             video = latents
         else:

@@ -741,6 +747,7 @@ def __call__(

         # Offload all models
         self.maybe_free_model_hooks()
+        reset_stateful_hooks(self.transformer, recurse=True)

         if not return_dict:
             return (video,)
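
The Mochi change normalizes the scheduler's reversed timesteps before exposing them, so the same timestep-range semantics apply across pipelines. A tiny sketch of the conversion; the 1000-step range is taken from the diff, while the sample scheduler values are illustrative.

# Per the comment in the diff, Mochi's scheduler timesteps are reversed relative to the other
# pipelines here, so the pipeline exposes 1000 - t to keep timestep-range checks consistent.
num_train_timesteps = 1000

for t in (1000.0, 750.0, 250.0, 0.0):  # illustrative scheduler timesteps
    current_timestep = num_train_timesteps - t  # value stored in self._current_timestep
    print(t, "->", current_timestep)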
