Commit 5f08ab5

test cfg

1 parent 2ee946f commit 5f08ab5

File tree

1 file changed: +145 -15 lines changed


src/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 145 additions & 15 deletions
@@ -314,12 +314,17 @@ def encode_prompt(
         self,
         prompt: Union[str, List[str]],
         prompt_2: Union[str, List[str]],
+        negative_prompt: Union[str, List[str]] = None,
+        negative_prompt_2: Union[str, List[str]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         max_sequence_length: int = 512,
         lora_scale: Optional[float] = None,
+        do_true_cfg: bool = False,
     ):
         r"""
@@ -356,24 +361,59 @@ def encode_prompt(
         scale_lora_layers(self.text_encoder_2, lora_scale)

         prompt = [prompt] if isinstance(prompt, str) else prompt
+        batch_size = len(prompt)
+
+        if do_true_cfg and negative_prompt is not None:
+            negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+            negative_batch_size = len(negative_prompt)
+
+            if negative_batch_size != batch_size:
+                raise ValueError(
+                    f"Negative prompt batch size ({negative_batch_size}) does not match prompt batch size ({batch_size})"
+                )
+
+            # Concatenate prompts
+            prompts = prompt + negative_prompt
+            prompts_2 = (
+                prompt_2 + negative_prompt_2 if prompt_2 is not None and negative_prompt_2 is not None else None
+            )
+        else:
+            prompts = prompt
+            prompts_2 = prompt_2

         if prompt_embeds is None:
-            prompt_2 = prompt_2 or prompt
-            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+            if prompts_2 is None:
+                prompts_2 = prompts

             # We only use the pooled prompt output from the CLIPTextModel
             pooled_prompt_embeds = self._get_clip_prompt_embeds(
-                prompt=prompt,
+                prompt=prompts,
                 device=device,
                 num_images_per_prompt=num_images_per_prompt,
             )
             prompt_embeds = self._get_t5_prompt_embeds(
-                prompt=prompt_2,
+                prompt=prompts_2,
                 num_images_per_prompt=num_images_per_prompt,
                 max_sequence_length=max_sequence_length,
                 device=device,
             )

+            if do_true_cfg and negative_prompt is not None:
+                # Split embeddings back into positive and negative parts
+                total_batch_size = batch_size * num_images_per_prompt
+                positive_indices = slice(0, total_batch_size)
+                negative_indices = slice(total_batch_size, 2 * total_batch_size)
+
+                positive_pooled_prompt_embeds = pooled_prompt_embeds[positive_indices]
+                negative_pooled_prompt_embeds = pooled_prompt_embeds[negative_indices]
+
+                positive_prompt_embeds = prompt_embeds[positive_indices]
+                negative_prompt_embeds = prompt_embeds[negative_indices]
+
+                pooled_prompt_embeds = positive_pooled_prompt_embeds
+                prompt_embeds = positive_prompt_embeds
+
+        # Unscale LoRA layers
         if self.text_encoder is not None:
             if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                 # Retrieve the original scale by scaling back the LoRA layers
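
The hunk above encodes the positive and negative prompts in a single pass through each text encoder by concatenating the prompt lists, then slicing the stacked embeddings back apart. A minimal standalone sketch of that pattern, with a toy encode() standing in for _get_clip_prompt_embeds / _get_t5_prompt_embeds:

import torch

def encode(prompts, num_images_per_prompt=1, dim=8):
    # toy stand-in for the CLIP/T5 helpers; each prompt's embedding is
    # repeated per image, keeping same-prompt rows adjacent
    embeds = torch.randn(len(prompts), dim)
    return embeds.repeat_interleave(num_images_per_prompt, dim=0)

prompt = ["a cat", "a dog"]
negative_prompt = ["blurry", "low quality"]
num_images_per_prompt = 2
batch_size = len(prompt)

# one encoder pass over positives followed by negatives
embeds = encode(prompt + negative_prompt, num_images_per_prompt)

# split back into positive and negative halves, as in the diff
total_batch_size = batch_size * num_images_per_prompt
positive_embeds = embeds[slice(0, total_batch_size)]
negative_embeds = embeds[slice(total_batch_size, 2 * total_batch_size)]
assert positive_embeds.shape == negative_embeds.shape == (total_batch_size, 8)
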
@@ -387,7 +427,16 @@ def encode_prompt(
         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
         text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

-        return prompt_embeds, pooled_prompt_embeds, text_ids
+        if do_true_cfg and negative_prompt is not None:
+            return (
+                prompt_embeds,
+                pooled_prompt_embeds,
+                text_ids,
+                negative_prompt_embeds,
+                negative_pooled_prompt_embeds,
+            )
+        else:
+            return prompt_embeds, pooled_prompt_embeds, text_ids, None, None

     def encode_image(self, image, device, num_images_per_prompt):
         dtype = next(self.image_encoder.parameters()).dtype
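
With this change encode_prompt always returns five values, padding with None when true CFG is off. A hedged sketch of the new call shape, assuming an instantiated FluxPipeline named pipe:

(
    prompt_embeds,
    pooled_prompt_embeds,
    text_ids,
    negative_prompt_embeds,
    negative_pooled_prompt_embeds,
) = pipe.encode_prompt(
    prompt="a photo of a cat",
    prompt_2=None,
    negative_prompt="blurry, low quality",
    do_true_cfg=True,
)
# the two trailing values are None unless do_true_cfg is True and a
# negative prompt is supplied
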
@@ -439,8 +488,12 @@ def check_inputs(
         prompt_2,
         height,
         width,
+        negative_prompt=None,
+        negative_prompt_2=None,
         prompt_embeds=None,
+        negative_prompt_embeds=None,
         pooled_prompt_embeds=None,
+        negative_pooled_prompt_embeds=None,
         callback_on_step_end_tensor_inputs=None,
         max_sequence_length=None,
     ):
@@ -475,10 +528,33 @@ def check_inputs(
         elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
             raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+
         if prompt_embeds is not None and pooled_prompt_embeds is None:
             raise ValueError(
                 "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
             )
+        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+            )

         if max_sequence_length is not None and max_sequence_length > 512:
             raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
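
These checks extend the existing prompt/prompt_embeds mutual-exclusion contract to the negative side. A small illustration of what now fails fast, with a hypothetical pipe and an illustrative embedding shape:

import torch

try:
    pipe.check_inputs(
        prompt="a cat",
        prompt_2=None,
        height=1024,
        width=1024,
        negative_prompt="blurry",                          # conflicting inputs:
        negative_prompt_embeds=torch.zeros(1, 512, 4096),  # prompt and embeds together
    )
except ValueError as err:
    print(err)  # Cannot forward both `negative_prompt` ... and `negative_prompt_embeds` ...
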
@@ -607,6 +683,9 @@ def __call__(
         self,
         prompt: Union[str, List[str]] = None,
         prompt_2: Optional[Union[str, List[str]]] = None,
+        negative_prompt: Union[str, List[str]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
+        true_cfg: float = 1.0,
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 28,
@@ -619,6 +698,10 @@ def __call__(
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         ip_adapter_image: Optional[PipelineImageInput] = None,
         ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+        negative_ip_adapter_image: Optional[PipelineImageInput] = None,
+        negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -673,6 +756,11 @@ def __call__(
                 Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
                 IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
                 provided, embeddings are computed from the `ip_adapter_image` input argument.
+            negative_ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+            negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list of the same length as the number of
+                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+                provided, embeddings are computed from the `negative_ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -710,8 +798,12 @@ def __call__(
             prompt_2,
             height,
             width,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
             prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
             pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
             max_sequence_length=max_sequence_length,
         )
@@ -733,21 +825,34 @@ def __call__(
         lora_scale = (
             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
         )
+        do_true_cfg = true_cfg > 1 and negative_prompt is not None
         (
             prompt_embeds,
             pooled_prompt_embeds,
             text_ids,
+            negative_prompt_embeds,
+            negative_pooled_prompt_embeds,
         ) = self.encode_prompt(
             prompt=prompt,
             prompt_2=prompt_2,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
             prompt_embeds=prompt_embeds,
             pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
             device=device,
             num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
             lora_scale=lora_scale,
+            do_true_cfg=do_true_cfg,
         )

+        if do_true_cfg:
+            # Concatenate embeddings
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
+
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
         latents, latent_image_ids = self.prepare_latents(
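
Note the concatenation order: negatives first, then positives. It has to match the noise_pred.chunk(2) unpacking in the denoising loop further down; a quick sanity sketch:

import torch

neg = torch.zeros(2, 4)  # stand-in negative embeddings
pos = torch.ones(2, 4)   # stand-in positive embeddings

batch = torch.cat([neg, pos], dim=0)  # same order as the diff
neg_half, pos_half = batch.chunk(2)   # same unpacking as the loop below
assert torch.equal(neg_half, neg) and torch.equal(pos_half, pos)
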
@@ -781,12 +886,17 @@ def __call__(
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)

-        # handle guidance
-        if self.transformer.config.guidance_embeds:
-            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
-            guidance = guidance.expand(latents.shape[0])
-        else:
-            guidance = None
+        if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
+            negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
+        ):
+            negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+        elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
+            negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
+        ):
+            ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+
+        if self.joint_attention_kwargs is None:
+            self._joint_attention_kwargs = {}

         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
             image_embeds = self.prepare_ip_adapter_image_embeds(
@@ -795,21 +905,37 @@ def __call__(
                 device,
                 batch_size * num_images_per_prompt,
             )
-            if self.joint_attention_kwargs is None:
-                self._joint_attention_kwargs = {}
             self._joint_attention_kwargs["image_projection"] = image_embeds
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+            negative_image_embeds = self.prepare_ip_adapter_image_embeds(
+                negative_ip_adapter_image,
+                negative_ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+            )
+            image_embeds = self._joint_attention_kwargs["image_projection"]
+            self._joint_attention_kwargs["image_projection"] = torch.cat([negative_image_embeds, image_embeds])

         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue

+                latent_model_input = torch.cat([latents] * 2) if do_true_cfg else latents
+
+                # handle guidance
+                if self.transformer.config.guidance_embeds:
+                    guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+                    guidance = guidance.expand(latent_model_input.shape[0])
+                else:
+                    guidance = None
+
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latents.shape[0]).to(latents.dtype)
+                timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)

                 noise_pred = self.transformer(
-                    hidden_states=latents,
+                    hidden_states=latent_model_input,
                     timestep=timestep / 1000,
                     guidance=guidance,
                     pooled_projections=pooled_prompt_embeds,
@@ -820,6 +946,10 @@ def __call__(
                     return_dict=False,
                 )[0]

+                if do_true_cfg:
+                    neg_noise_pred, noise_pred = noise_pred.chunk(2)
+                    noise_pred = neg_noise_pred + true_cfg * (noise_pred - neg_noise_pred)
+
                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
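
The combination applied above is the standard classifier-free guidance formula, noise_pred = neg + true_cfg * (pos - neg): true_cfg = 1.0 reproduces the positive prediction, and larger values push the result away from the negative prompt. A worked toy example:

import torch

true_cfg = 4.0
neg_noise_pred = torch.tensor([0.0, 1.0])
pos_noise_pred = torch.tensor([1.0, 3.0])

noise_pred = neg_noise_pred + true_cfg * (pos_noise_pred - neg_noise_pred)
print(noise_pred)  # tensor([4., 9.])

End to end, the new arguments would be exercised roughly as below; a minimal sketch assuming a stock Flux checkpoint, with model id and prompts purely illustrative. Note that true_cfg is distinct from guidance_scale, which still feeds Flux's distilled guidance embedding:

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

image = pipe(
    prompt="a cat holding a sign that says hello world",
    negative_prompt="blurry, low quality",
    true_cfg=4.0,        # > 1 enables true CFG (requires a negative prompt)
    guidance_scale=3.5,  # embedded distilled guidance, unchanged by this commit
    num_inference_steps=28,
).images[0]
image.save("cat.png")
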

0 commit comments
