 from transformers import AutoTokenizer, UMT5EncoderModel
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...guiders import ClassifierFreeGuidance, GuidanceMixin, _raise_guidance_deprecation_warning
 from ...loaders import WanLoraLoaderMixin
 from ...models import AutoencoderKLWan, WanTransformer3DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -380,6 +381,7 @@ def __call__(
         ] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 512,
+        guidance: Optional[GuidanceMixin] = None,
     ):
         r"""
         The call function to the pipeline for generation.
@@ -444,6 +446,10 @@ def __call__(
                 indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
         """
 
+        _raise_guidance_deprecation_warning(guidance_scale=guidance_scale)
+        if guidance is None:
+            guidance = ClassifierFreeGuidance(guidance_scale=guidance_scale)
+
         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
 
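For context, a minimal usage sketch of the new argument. This is illustrative only: it assumes `ClassifierFreeGuidance` is importable as `diffusers.guiders.ClassifierFreeGuidance` (matching the relative import added above), and the checkpoint ID and generation settings are placeholders, not part of this change.

import torch
from diffusers import WanPipeline
from diffusers.guiders import ClassifierFreeGuidance  # assumed public path

# Placeholder checkpoint; any Wan T2V Diffusers checkpoint should work here.
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# New path: pass a guider object explicitly.
output = pipe(
    prompt="A cat walks on the grass, realistic",
    negative_prompt="blurry, low quality",
    guidance=ClassifierFreeGuidance(guidance_scale=5.0),
)
video = output.frames[0]

# Old path: pipe(..., guidance_scale=5.0) still works; per the fallback in this
# hunk, it routes through the deprecation shim and is wrapped in a
# ClassifierFreeGuidance guider internally.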
@@ -519,37 +525,38 @@ def __call__(
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
 
+        conds = [prompt_embeds, negative_prompt_embeds]
+        prompt_embeds, negative_prompt_embeds = [[c] for c in conds]
+
         with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
             for i, t in enumerate(timesteps):
+                self._current_timestep = t
                 if self.interrupt:
                     continue
 
-                self._current_timestep = t
-                latent_model_input = latents.to(transformer_dtype)
-                timestep = t.expand(latents.shape[0])
-
-                cc.mark_state("cond")
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if self.do_classifier_free_guidance:
-                    cc.mark_state("uncond")
-                    noise_uncond = self.transformer(
-                        hidden_states=latent_model_input,
+                guidance.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
+                guidance.prepare_models(self.transformer)
+                latents, prompt_embeds = guidance.prepare_inputs(
+                    latents, (prompt_embeds[0], negative_prompt_embeds[0])
+                )
+
+                for batch_index, (latent, condition) in enumerate(zip(latents, prompt_embeds)):
+                    cc.mark_state(f"batch_{batch_index}")
+                    latent = latent.to(transformer_dtype)
+                    timestep = t.expand(latent.shape[0])
+                    noise_pred = self.transformer(
+                        hidden_states=latent,
                         timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states=condition,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
-                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+                    guidance.prepare_outputs(noise_pred)
 
-                # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+                outputs = guidance.outputs
+                noise_pred = guidance(**outputs)
+                latents = self.scheduler.step(noise_pred, t, latents[0], return_dict=False)[0]
+                guidance.cleanup_models(self.transformer)
 
                 if callback_on_step_end is not None:
                     callback_kwargs = {}
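The rewritten loop drives the guider through a fixed per-step protocol: set_state, prepare_models, prepare_inputs (one latent/condition pair per required forward pass), prepare_outputs after each pass, the combine via guidance(**guidance.outputs), then cleanup_models. Below is a toy sketch of that protocol for plain CFG, reconstructed only from the calls visible in this diff; the keyword names pred_cond/pred_uncond and the reset-in-set_state behavior are assumptions, and this is not the real GuidanceMixin implementation from diffusers.guiders.

class ToyCFG:
    # Minimal stand-in for the protocol this loop expects; illustrative only.
    def __init__(self, guidance_scale=5.0):
        self.guidance_scale = guidance_scale
        self._preds = []

    def set_state(self, step, num_inference_steps, timestep):
        # Called once per denoising step, before any forward pass.
        self._preds = []

    def prepare_models(self, transformer):
        pass  # hook point, e.g. for toggling per-pass model behavior

    def prepare_inputs(self, latents, conditions):
        # Two passes for CFG: conditional, then unconditional.
        cond, uncond = conditions
        return [latents, latents], [cond, uncond]

    def prepare_outputs(self, noise_pred):
        # Collect one prediction per forward pass, in prepare_inputs order.
        self._preds.append(noise_pred)

    @property
    def outputs(self):
        return {"pred_cond": self._preds[0], "pred_uncond": self._preds[1]}

    def __call__(self, pred_cond, pred_uncond):
        # The same combine this diff removes from the loop body.
        return pred_uncond + self.guidance_scale * (pred_cond - pred_uncond)

    def cleanup_models(self, transformer):
        pass  # undo anything prepare_models changed

Running the passes sequentially lets the pipeline tag each one for the cache context (cc.mark_state(f"batch_{batch_index}")), and guidance strategies that need one, two, or more forward passes plug in without further changes to the loop.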
@@ -558,8 +565,10 @@ def __call__(
                     callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
 
                     latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    prompt_embeds = [callback_outputs.pop("prompt_embeds", prompt_embeds[0])]
+                    negative_prompt_embeds = [
+                        callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds[0])
+                    ]
 
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):