
Commit 74e34e5

guiders support for wan
1 parent: 357f4f0

File tree

4 files changed: +124 -25 lines

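Before the per-file diffs, a minimal usage sketch of what this commit enables: passing an explicit guider object to WanPipeline in place of the legacy guidance_scale-only path. The guidance= argument and the ClassifierFreeGuidance(guidance_scale=...) constructor are taken from the diff below; the checkpoint name and the .frames output attribute are assumptions for illustration.

import torch
from diffusers import WanPipeline
from diffusers.guiders import ClassifierFreeGuidance

# Example checkpoint (assumed); any diffusers-format Wan T2V checkpoint should work.
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# An explicit guider object replaces the implicit CFG branch removed in this commit.
guidance = ClassifierFreeGuidance(guidance_scale=5.0)
video = pipe(
    prompt="A cat walks on the grass, realistic",
    num_inference_steps=30,
    guidance=guidance,
).frames[0]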

src/diffusers/hooks/_helpers.py

Lines changed: 9 additions & 1 deletion
@@ -30,7 +30,7 @@
 )
 from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
 from ..models.transformers.transformer_mochi import MochiTransformerBlock
-from ..models.transformers.transformer_wan import WanTransformerBlock
+from ..models.transformers.transformer_wan import WanAttnProcessor2_0, WanPAGAttnProcessor2_0, WanTransformerBlock


 @dataclass
@@ -186,6 +186,14 @@ def _register_guidance_metadata():
         ),
     )

+    # Wan
+    GuidanceMetadataRegistry.register(
+        model_class=WanAttnProcessor2_0,
+        metadata=GuidanceMetadata(
+            perturbed_attention_guidance_processor_cls=WanPAGAttnProcessor2_0,
+        ),
+    )
+

 # fmt: off
 def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs):
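The registration above maps Wan's default attention processor to its PAG counterpart. A sketch of how guidance code might consume the entry; the GuidanceMetadataRegistry.get accessor and the set_processor swap are assumptions for illustration, not the commit's actual hook logic.

from diffusers.hooks._helpers import GuidanceMetadataRegistry
from diffusers.models.transformers.transformer_wan import WanAttnProcessor2_0

# Look up the metadata registered for Wan's default processor (assumes a `get` accessor).
metadata = GuidanceMetadataRegistry.get(WanAttnProcessor2_0)
pag_processor_cls = metadata.perturbed_attention_guidance_processor_cls

def apply_pag(transformer, layer_names):
    # Hypothetical helper: swap in the PAG processor on selected
    # self-attention layers for the perturbed forward pass.
    for name, module in transformer.named_modules():
        if name in layer_names and hasattr(module, "set_processor"):
            module.set_processor(pag_processor_cls())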

src/diffusers/models/transformers/transformer_wan.py

Lines changed: 82 additions & 0 deletions
@@ -467,3 +467,85 @@ def forward(
             return (output,)

         return Transformer2DModelOutput(sample=output)
+
+
+### ===== Custom attention processors for guidance methods ===== ###
+
+
+class WanPAGAttnProcessor2_0:
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("WanPAGAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        is_encoder_hidden_states_provided = encoder_hidden_states is not None
+        encoder_hidden_states_img = None
+        if attn.add_k_proj is not None:
+            encoder_hidden_states_img = encoder_hidden_states[:, :257]
+            encoder_hidden_states = encoder_hidden_states[:, 257:]
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+        if rotary_emb is not None:
+
+            def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
+                x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
+                x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
+                return x_out.type_as(hidden_states)
+
+            query = apply_rotary_emb(query, rotary_emb)
+            key = apply_rotary_emb(key, rotary_emb)
+
+        # I2V task
+        hidden_states_img = None
+        if encoder_hidden_states_img is not None:
+            key_img = attn.add_k_proj(encoder_hidden_states_img)
+            key_img = attn.norm_added_k(key_img)
+            value_img = attn.add_v_proj(encoder_hidden_states_img)
+
+            key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+            value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+            hidden_states_img = F.scaled_dot_product_attention(
+                query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
+            )
+            hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
+            hidden_states_img = hidden_states_img.type_as(query)
+
+        if is_encoder_hidden_states_provided:
+            hidden_states = F.scaled_dot_product_attention(
+                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            )
+        else:
+            # Perturbed attention applies only to self-attention: use the value projections (an identity attention map).
+            hidden_states = value
+
+        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.type_as(query)
+
+        if hidden_states_img is not None:
+            hidden_states = hidden_states + hidden_states_img
+
+        hidden_states = attn.to_out[0](hidden_states)
+        hidden_states = attn.to_out[1](hidden_states)
+        return hidden_states
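The essential PAG behavior is the final else branch: when the processor runs as self-attention (no encoder_hidden_states), it skips scaled dot-product attention entirely and passes the value projections through, which is equivalent to applying an identity attention map. A minimal numeric sketch of the two branches (shapes: batch, heads, sequence, head_dim):

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

normal_out = F.scaled_dot_product_attention(q, k, v)  # standard self-attention
perturbed_out = v  # PAG branch: identity attention map over the values

assert perturbed_out.shape == normal_out.shape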

src/diffusers/pipelines/cogview4/pipeline_cogview4.py

Lines changed: 1 addition & 1 deletion
@@ -617,7 +617,7 @@ def __call__(
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

         conds = [prompt_embeds, negative_prompt_embeds, original_size, target_size, crops_coords_top_left]
-        prompt_embeds, negative_prompt_embeds, original_size, target_size, crops_coords_top_left = [[v] for v in conds]
+        prompt_embeds, negative_prompt_embeds, original_size, target_size, crops_coords_top_left = [[c] for c in conds]

         with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
             for i, t in enumerate(timesteps):

src/diffusers/pipelines/wan/pipeline_wan.py

Lines changed: 32 additions & 23 deletions
@@ -21,6 +21,7 @@
 from transformers import AutoTokenizer, UMT5EncoderModel

 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...guiders import ClassifierFreeGuidance, GuidanceMixin, _raise_guidance_deprecation_warning
 from ...loaders import WanLoraLoaderMixin
 from ...models import AutoencoderKLWan, WanTransformer3DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -380,6 +381,7 @@ def __call__(
         ] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 512,
+        guidance: Optional[GuidanceMixin] = None,
     ):
         r"""
         The call function to the pipeline for generation.
@@ -444,6 +446,10 @@ def __call__(
                 indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
         """

+        _raise_guidance_deprecation_warning(guidance_scale=guidance_scale)
+        if guidance is None:
+            guidance = ClassifierFreeGuidance(guidance_scale=guidance_scale)
+
         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

@@ -519,37 +525,38 @@ def __call__(
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)

+        conds = [prompt_embeds, negative_prompt_embeds]
+        prompt_embeds, negative_prompt_embeds = [[c] for c in conds]
+
         with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
             for i, t in enumerate(timesteps):
+                self._current_timestep = t
                 if self.interrupt:
                     continue

-                self._current_timestep = t
-                latent_model_input = latents.to(transformer_dtype)
-                timestep = t.expand(latents.shape[0])
-
-                cc.mark_state("cond")
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if self.do_classifier_free_guidance:
-                    cc.mark_state("uncond")
-                    noise_uncond = self.transformer(
-                        hidden_states=latent_model_input,
+                guidance.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
+                guidance.prepare_models(self.transformer)
+                latents, prompt_embeds = guidance.prepare_inputs(
+                    latents, (prompt_embeds[0], negative_prompt_embeds[0])
+                )
+
+                for batch_index, (latent, condition) in enumerate(zip(latents, prompt_embeds)):
+                    cc.mark_state(f"batch_{batch_index}")
+                    latent = latent.to(transformer_dtype)
+                    timestep = t.expand(latent.shape[0])
+                    noise_pred = self.transformer(
+                        hidden_states=latent,
                         timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states=condition,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
-                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+                    guidance.prepare_outputs(noise_pred)

-                # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+                outputs = guidance.outputs
+                noise_pred = guidance(**outputs)
+                latents = self.scheduler.step(noise_pred, t, latents[0], return_dict=False)[0]
+                guidance.cleanup_models(self.transformer)

                 if callback_on_step_end is not None:
                     callback_kwargs = {}
@@ -558,8 +565,10 @@ def __call__(
                     callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                     latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    prompt_embeds = [callback_outputs.pop("prompt_embeds", prompt_embeds[0])]
+                    negative_prompt_embeds = [
+                        callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds[0])
+                    ]

                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
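The rewritten loop delegates batching and prediction-combining to the guider through a fixed call sequence: set_state, prepare_models, prepare_inputs, one transformer forward per (latent, condition) pair, prepare_outputs, then outputs/__call__ and cleanup_models. A no-op guider sketch satisfying that contract, inferred from the calls the loop makes (the real GuidanceMixin base class in diffusers.guiders may differ):

class NoOpGuidance:
    def set_state(self, step, num_inference_steps, timestep):
        self._step = step  # record where we are in the schedule

    def prepare_models(self, denoiser):
        pass  # hook point, e.g. for swapping in PAG attention processors

    def prepare_inputs(self, latents, conditions):
        # Return one (latent, condition) pair per forward pass the loop should run.
        return [latents], [conditions[0]]

    def prepare_outputs(self, noise_pred):
        self._preds = [noise_pred]

    @property
    def outputs(self):
        return {"pred": self._preds[0]}

    def __call__(self, pred):
        return pred  # CFG would blend cond/uncond predictions here

    def cleanup_models(self, denoiser):
        pass  # undo anything done in prepare_models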
