Commit d2f3a86

refactor pipeline
1 parent c539784 commit d2f3a86

File tree: 2 files changed, +51 -40 lines changed


src/diffusers/pipelines/wan/pipeline_wan.py

Lines changed: 21 additions & 13 deletions
@@ -199,6 +199,7 @@ def encode_prompt(
         self,
         prompt: Union[str, List[str]],
         negative_prompt: Optional[Union[str, List[str]]] = None,
+        do_classifier_free_guidance: bool = False,
         num_videos_per_prompt: int = 1,
         prompt_embeds: Optional[torch.Tensor] = None,
         negative_prompt_embeds: Optional[torch.Tensor] = None,
@@ -216,6 +217,8 @@ def encode_prompt(
                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                 less than `1`).
+            do_classifier_free_guidance (`bool`, *optional*, defaults to `False`):
+                Whether to use classifier-free guidance or not.
             num_videos_per_prompt (`int`, *optional*, defaults to 1):
                 Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
             prompt_embeds (`torch.Tensor`, *optional*):
@@ -247,7 +250,7 @@ def encode_prompt(
                 dtype=dtype,
             )
 
-        if negative_prompt_embeds is None:
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
             negative_prompt = negative_prompt or ""
             negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
 
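
With this gating, encode_prompt skips the negative-prompt branch entirely when guidance is off, so no unconditional embeddings are computed. A minimal sketch of the new control flow, as a hedged illustration only (encode_fn is a hypothetical stand-in for the tokenizer-plus-text-encoder work the real method does):

# Sketch of the gating added to encode_prompt; not the full method.
def encode_prompt_sketch(encode_fn, prompt, negative_prompt=None,
                         do_classifier_free_guidance=False,
                         prompt_embeds=None, negative_prompt_embeds=None):
    if prompt_embeds is None:
        prompt_embeds = encode_fn(prompt)
    # The unconditional branch now runs only when CFG actually needs it.
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        negative_prompt_embeds = encode_fn(negative_prompt or "")
    return prompt_embeds, negative_prompt_embeds
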
@@ -348,6 +351,10 @@ def prepare_latents(
     def guidance_scale(self):
         return self._guidance_scale
 
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1.0
+
     @property
     def num_timesteps(self):
         return self._num_timesteps
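
The new property encodes the usual diffusers convention: guidance is a no-op at a scale of exactly 1, because the CFG combination collapses to the conditional prediction, so any scale at or below 1 makes the unconditional pass pure overhead. A small self-contained check of that identity:

import torch

# At scale 1.0 the CFG formula reduces to the conditional prediction,
# which is why do_classifier_free_guidance tests guidance_scale > 1.0.
noise_cond = torch.randn(4)
noise_uncond = torch.randn(4)
for scale in (1.0, 5.0):
    combined = noise_uncond + scale * (noise_cond - noise_uncond)
    print(scale, torch.allclose(combined, noise_cond))  # True only at 1.0
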
@@ -377,7 +384,6 @@ def __call__(
         latents: Optional[torch.Tensor] = None,
         prompt_embeds: Optional[torch.Tensor] = None,
         negative_prompt_embeds: Optional[torch.Tensor] = None,
-        prompt_attention_mask: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "np",
         return_dict: bool = True,
         callback_on_step_end: Optional[
@@ -477,6 +483,7 @@ def __call__(
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt=prompt,
             negative_prompt=negative_prompt,
+            do_classifier_free_guidance=self.do_classifier_free_guidance,
             num_videos_per_prompt=num_videos_per_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
@@ -486,7 +493,8 @@ def __call__(
 
         transformer_dtype = self.transformer.dtype
         prompt_embeds = prompt_embeds.to(transformer_dtype)
-        negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+        if negative_prompt_embeds is not None:
+            negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
 
         # 4. Prepare timesteps
         self.scheduler.flow_shift = flow_shift
@@ -523,22 +531,22 @@ def __call__(
                 latent_model_input = latents.to(transformer_dtype)
                 timestep = t.expand(latents.shape[0]).to(transformer_dtype)
 
-                noise_pred = self.transformer(
+                noise_cond = self.transformer(
                     hidden_states=latent_model_input,
                     timestep=timestep,
                     encoder_hidden_states=prompt_embeds,
                     return_dict=False,
                 )[0]
 
-                noise_pred_negative = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=negative_prompt_embeds,
-                    return_dict=False,
-                )[0]
-
-                noise_pred = noise_pred_negative + guidance_scale * (noise_pred - noise_pred_negative)
-
+                if self.do_classifier_free_guidance:
+                    noise_uncond = self.transformer(
+                        hidden_states=latent_model_input,
+                        timestep=timestep,
+                        encoder_hidden_states=negative_prompt_embeds,
+                        return_dict=False,
+                    )[0]
+                    noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)
+
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
 
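
Taken together, the loop now computes the conditional prediction once and pays for the unconditional pass only when guidance is active. One caveat: the hunk above assigns noise_pred only inside the guidance branch, so a guidance scale at or below 1 would leave it undefined at scheduler.step. A hedged sketch of the same step with a defensive fallback (the helper name and its flat argument list are illustrative, not from this commit):

import torch

def cfg_denoise_step(transformer, scheduler, latents, latent_model_input, t,
                     timestep, prompt_embeds, negative_prompt_embeds,
                     guidance_scale, do_classifier_free_guidance):
    # Conditional prediction: always needed.
    noise_cond = transformer(
        hidden_states=latent_model_input,
        timestep=timestep,
        encoder_hidden_states=prompt_embeds,
        return_dict=False,
    )[0]
    if do_classifier_free_guidance:
        # Second forward pass only while guidance is active.
        noise_uncond = transformer(
            hidden_states=latent_model_input,
            timestep=timestep,
            encoder_hidden_states=negative_prompt_embeds,
            return_dict=False,
        )[0]
        noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)
    else:
        # Fallback so scheduler.step always receives a defined prediction.
        noise_pred = noise_cond
    # compute the previous noisy sample x_t -> x_t-1
    return scheduler.step(noise_pred, t, latents, return_dict=False)[0]
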
src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Lines changed: 30 additions & 27 deletions
@@ -242,6 +242,7 @@ def encode_prompt(
         self,
         prompt: Union[str, List[str]],
         negative_prompt: Optional[Union[str, List[str]]] = None,
+        do_classifier_free_guidance: bool = False,
         num_videos_per_prompt: int = 1,
         prompt_embeds: Optional[torch.Tensor] = None,
         negative_prompt_embeds: Optional[torch.Tensor] = None,
@@ -259,6 +260,8 @@ def encode_prompt(
                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                 less than `1`).
+            do_classifier_free_guidance (`bool`, *optional*, defaults to `False`):
+                Whether to use classifier-free guidance or not.
             num_videos_per_prompt (`int`, *optional*, defaults to 1):
                 Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
             prompt_embeds (`torch.Tensor`, *optional*):
@@ -290,7 +293,7 @@ def encode_prompt(
                 dtype=dtype,
             )
 
-        if negative_prompt_embeds is None:
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
             negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
 
@@ -439,6 +442,10 @@ def prepare_latents(
     def guidance_scale(self):
         return self._guidance_scale
 
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1
+
     @property
     def num_timesteps(self):
         return self._num_timesteps
@@ -468,15 +475,13 @@ def __call__(
         latents: Optional[torch.Tensor] = None,
         prompt_embeds: Optional[torch.Tensor] = None,
         negative_prompt_embeds: Optional[torch.Tensor] = None,
-        prompt_attention_mask: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "np",
         return_dict: bool = True,
         callback_on_step_end: Optional[
             Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
         ] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 512,
-        autocast_dtype: torch.dtype = torch.bfloat16,
     ):
         r"""
         The call function to the pipeline for generation.
@@ -571,20 +576,22 @@ def __call__(
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt=prompt,
             negative_prompt=negative_prompt,
+            do_classifier_free_guidance=self.do_classifier_free_guidance,
             num_videos_per_prompt=num_videos_per_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             max_sequence_length=max_sequence_length,
             device=device,
-            dtype=autocast_dtype,
         )
-        # encode image embedding
+
+        # Encode image embedding
         image_embeds = self.encode_image(image)
         image_embeds = image_embeds.repeat(batch_size, 1, 1)
 
-        prompt_embeds = prompt_embeds.to(autocast_dtype)
-        negative_prompt_embeds = negative_prompt_embeds.to(autocast_dtype)
-        image_embeds = image_embeds.to(autocast_dtype)
+        transformer_dtype = self.transformer.dtype
+        prompt_embeds = prompt_embeds.to(transformer_dtype)
+        negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+        image_embeds = image_embeds.to(transformer_dtype)
 
         # 4. Prepare timesteps
         self.scheduler.flow_shift = flow_shift
@@ -596,6 +603,7 @@ def __call__(
             height, width = image.shape[-2:]
         else:
             width, height = image.size
+
         # 5. Prepare latent variables
         num_channels_latents = self.vae.config.z_dim
         num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
@@ -618,37 +626,32 @@ def __call__(
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
 
-        with (
-            self.progress_bar(total=num_inference_steps) as progress_bar,
-            amp.autocast('cuda', dtype=autocast_dtype, cache_enabled=False)
-        ):
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
 
                 self._current_timestep = t
-                latent_model_input = latents
-                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latents.shape[0])
+                latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
+                timestep = t.expand(latents.shape[0]).to(transformer_dtype)
 
-                noise_pred = self.transformer(
-                    hidden_states=torch.concat([latent_model_input, condition], dim=1),
+                noise_cond = self.transformer(
+                    hidden_states=latent_model_input,
                     timestep=timestep,
                     encoder_hidden_states=prompt_embeds,
                     encoder_hidden_states_image=image_embeds,
                     return_dict=False,
                 )[0]
 
-                noise_pred_negative = self.transformer(
-                    hidden_states=torch.concat([latent_model_input, condition], dim=1),
-                    timestep=timestep,
-                    encoder_hidden_states=negative_prompt_embeds,
-                    encoder_hidden_states_image=image_embeds,
-                    return_dict=False,
-                )[0]
-
-                noise_pred = noise_pred_negative + guidance_scale * (
-                    noise_pred - noise_pred_negative)
+                if self.do_classifier_free_guidance:
+                    noise_uncond = self.transformer(
+                        hidden_states=latent_model_input,
+                        timestep=timestep,
+                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states_image=image_embeds,
+                        return_dict=False,
+                    )[0]
+                    noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
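
For orientation, this is how the refactored image-to-video pipeline would be driven. A hedged usage sketch, not part of the commit: the checkpoint id and input image are placeholders, and since this file (unlike pipeline_wan.py) still calls negative_prompt_embeds.to(...) unguarded, the example keeps guidance enabled.

import torch
from diffusers import WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

# Placeholder checkpoint id; substitute any Wan I2V checkpoint in diffusers format.
pipe = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", torch_dtype=torch.bfloat16
).to("cuda")

image = load_image("input_frame.png")  # placeholder conditioning frame

# guidance_scale > 1 flips do_classifier_free_guidance on, so each step also
# runs the unconditional transformer pass before combining the predictions.
output = pipe(
    image=image,
    prompt="a cat surfing a wave at sunset",
    guidance_scale=5.0,
)
export_to_video(output.frames[0], "output.mp4", fps=16)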
