
Commit d794ab5

test cfg
1 parent 2ee946f commit d794ab5

3 files changed: +146 -12 lines changed


src/diffusers/models/attention_processor.py

Lines changed: 1 addition & 4 deletions
@@ -2530,8 +2530,6 @@ def __call__(
         image_projection: Optional[List[torch.Tensor]] = None,
         ip_adapter_masks: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
-        if image_projection is None:
-            raise ValueError("image_projection is None")
         batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

         # `sample` projections.
@@ -2606,8 +2604,7 @@ def __call__(
         ip_query = hidden_states_query_proj
         ip_attn_output = None
         # for ip-adapter
-        # TODO: fix for multiple
-        # NOTE: run zeros image embed at the same time?
+        # TODO: support for multiple adapters
         for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
             ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
         ):
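Why the hard error goes away: the pipeline change below appears to write `image_projection` into `joint_attention_kwargs` on every denoising step, and that value is legitimately `None` whenever no IP-Adapter input is supplied for a given pass. A minimal sketch of the resulting contract (an illustrative stand-in, not the actual diffusers processor):

```python
# Illustrative stand-in only: with the ValueError removed, a None
# image_projection means "skip the IP-Adapter branch", not a usage error.
from typing import List, Optional

import torch


def ip_adapter_branch(
    hidden_states: torch.Tensor,
    image_projection: Optional[List[torch.Tensor]] = None,
) -> torch.Tensor:
    if image_projection is None:
        # No image conditioning on this pass; return text-only states unchanged.
        return hidden_states
    out = hidden_states
    for ip_embeds in image_projection:
        # Each adapter would add a scaled cross-attention term here; a no-op
        # placeholder keeps this sketch runnable.
        out = out + 0.0 * ip_embeds.mean()
    return out
```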

src/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 143 additions & 8 deletions
@@ -178,7 +178,7 @@ class FluxPipeline(
     """

     model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
-    _optional_components = []
+    _optional_components = ["image_encoder", "feature_extractor"]
     _callback_tensor_inputs = ["latents", "prompt_embeds"]

     def __init__(
@@ -314,12 +314,17 @@ def encode_prompt(
         self,
         prompt: Union[str, List[str]],
         prompt_2: Union[str, List[str]],
+        negative_prompt: Union[str, List[str]] = None,
+        negative_prompt_2: Union[str, List[str]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         max_sequence_length: int = 512,
         lora_scale: Optional[float] = None,
+        do_true_cfg: bool = False,
     ):
         r"""

@@ -356,24 +361,59 @@ def encode_prompt(
             scale_lora_layers(self.text_encoder_2, lora_scale)

         prompt = [prompt] if isinstance(prompt, str) else prompt
+        batch_size = len(prompt)
+
+        if do_true_cfg and negative_prompt is not None:
+            negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+            negative_batch_size = len(negative_prompt)
+
+            if negative_batch_size != batch_size:
+                raise ValueError(
+                    f"Negative prompt batch size ({negative_batch_size}) does not match prompt batch size ({batch_size})"
+                )
+
+            # Concatenate prompts
+            prompts = prompt + negative_prompt
+            prompts_2 = (
+                prompt_2 + negative_prompt_2 if prompt_2 is not None and negative_prompt_2 is not None else None
+            )
+        else:
+            prompts = prompt
+            prompts_2 = prompt_2

         if prompt_embeds is None:
-            prompt_2 = prompt_2 or prompt
-            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+            if prompts_2 is None:
+                prompts_2 = prompts

             # We only use the pooled prompt output from the CLIPTextModel
             pooled_prompt_embeds = self._get_clip_prompt_embeds(
-                prompt=prompt,
+                prompt=prompts,
                 device=device,
                 num_images_per_prompt=num_images_per_prompt,
             )
             prompt_embeds = self._get_t5_prompt_embeds(
-                prompt=prompt_2,
+                prompt=prompts_2,
                 num_images_per_prompt=num_images_per_prompt,
                 max_sequence_length=max_sequence_length,
                 device=device,
             )

+        if do_true_cfg and negative_prompt is not None:
+            # Split embeddings back into positive and negative parts
+            total_batch_size = batch_size * num_images_per_prompt
+            positive_indices = slice(0, total_batch_size)
+            negative_indices = slice(total_batch_size, 2 * total_batch_size)
+
+            positive_pooled_prompt_embeds = pooled_prompt_embeds[positive_indices]
+            negative_pooled_prompt_embeds = pooled_prompt_embeds[negative_indices]
+
+            positive_prompt_embeds = prompt_embeds[positive_indices]
+            negative_prompt_embeds = prompt_embeds[negative_indices]
+
+            pooled_prompt_embeds = positive_pooled_prompt_embeds
+            prompt_embeds = positive_prompt_embeds
+
+        # Unscale LoRA layers
         if self.text_encoder is not None:
             if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                 # Retrieve the original scale by scaling back the LoRA layers
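The `do_true_cfg` branch above batches positive and negative prompts through the text encoders in one forward pass and slices the result apart afterwards. A toy illustration of the concatenate-then-split bookkeeping (`fake_encoder` is a hypothetical stand-in for `_get_t5_prompt_embeds`):

```python
import torch


def fake_encoder(prompts, num_images_per_prompt=1):
    # Stand-in: one 8-dim embedding per prompt, repeated per image the way
    # the real encoder helpers tile embeddings per prompt.
    embeds = torch.randn(len(prompts), 8)
    return embeds.repeat_interleave(num_images_per_prompt, dim=0)


prompt = ["a cat", "a dog"]
negative_prompt = ["blurry", "low quality"]
num_images_per_prompt = 2

embeds = fake_encoder(prompt + negative_prompt, num_images_per_prompt)
total_batch_size = len(prompt) * num_images_per_prompt
positive = embeds[slice(0, total_batch_size)]
negative = embeds[slice(total_batch_size, 2 * total_batch_size)]
assert positive.shape == negative.shape == (total_batch_size, 8)
```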
@@ -387,7 +427,16 @@ def encode_prompt(
         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
         text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

-        return prompt_embeds, pooled_prompt_embeds, text_ids
+        if do_true_cfg and negative_prompt is not None:
+            return (
+                prompt_embeds,
+                pooled_prompt_embeds,
+                text_ids,
+                negative_prompt_embeds,
+                negative_pooled_prompt_embeds,
+            )
+        else:
+            return prompt_embeds, pooled_prompt_embeds, text_ids, None, None

     def encode_image(self, image, device, num_images_per_prompt):
         dtype = next(self.image_encoder.parameters()).dtype
@@ -439,8 +488,12 @@ def check_inputs(
         prompt_2,
         height,
         width,
+        negative_prompt=None,
+        negative_prompt_2=None,
         prompt_embeds=None,
+        negative_prompt_embeds=None,
         pooled_prompt_embeds=None,
+        negative_pooled_prompt_embeds=None,
         callback_on_step_end_tensor_inputs=None,
         max_sequence_length=None,
     ):
@@ -475,10 +528,33 @@ def check_inputs(
         elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
             raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+
         if prompt_embeds is not None and pooled_prompt_embeds is None:
             raise ValueError(
                 "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
             )
+        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+            )

         if max_sequence_length is not None and max_sequence_length > 512:
             raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
@@ -607,6 +683,9 @@ def __call__(
         self,
         prompt: Union[str, List[str]] = None,
         prompt_2: Optional[Union[str, List[str]]] = None,
+        negative_prompt: Union[str, List[str]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
+        true_cfg: float = 1.0,
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 28,
@@ -619,6 +698,10 @@ def __call__(
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         ip_adapter_image: Optional[PipelineImageInput] = None,
         ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+        negative_ip_adapter_image: Optional[PipelineImageInput] = None,
+        negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -673,6 +756,11 @@ def __call__(
                 Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
                 IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
                 provided, embeddings are computed from the `ip_adapter_image` input argument.
+            negative_ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+            negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+                provided, embeddings are computed from the `negative_ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -710,8 +798,12 @@ def __call__(
             prompt_2,
             height,
             width,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
             prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
             pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
             max_sequence_length=max_sequence_length,
         )
@@ -733,19 +825,27 @@ def __call__(
         lora_scale = (
             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
         )
+        do_true_cfg = true_cfg > 1 and negative_prompt is not None
         (
             prompt_embeds,
             pooled_prompt_embeds,
             text_ids,
+            negative_prompt_embeds,
+            negative_pooled_prompt_embeds,
         ) = self.encode_prompt(
             prompt=prompt,
             prompt_2=prompt_2,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
             prompt_embeds=prompt_embeds,
             pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
             device=device,
             num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
             lora_scale=lora_scale,
+            do_true_cfg=do_true_cfg,
         )

         # 4. Prepare latent variables
@@ -788,23 +888,43 @@ def __call__(
         else:
             guidance = None

+        if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
+            negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
+        ):
+            negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+        elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
+            negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
+        ):
+            ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+
+        if self.joint_attention_kwargs is None:
+            self._joint_attention_kwargs = {}
+
+        image_embeds = None
+        negative_image_embeds = None
         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
             image_embeds = self.prepare_ip_adapter_image_embeds(
                 ip_adapter_image,
                 ip_adapter_image_embeds,
                 device,
                 batch_size * num_images_per_prompt,
             )
-            if self.joint_attention_kwargs is None:
-                self._joint_attention_kwargs = {}
             self._joint_attention_kwargs["image_projection"] = image_embeds
+        if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
+            negative_image_embeds = self.prepare_ip_adapter_image_embeds(
+                negative_ip_adapter_image,
+                negative_ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+            )

         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue

+                self._joint_attention_kwargs["image_projection"] = image_embeds
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
@@ -820,6 +940,21 @@ def __call__(
                     return_dict=False,
                 )[0]

+                if do_true_cfg:
+                    self._joint_attention_kwargs["image_projection"] = negative_image_embeds
+                    neg_noise_pred = self.transformer(
+                        hidden_states=latents,
+                        timestep=timestep / 1000,
+                        guidance=guidance,
+                        pooled_projections=negative_pooled_prompt_embeds,
+                        encoder_hidden_states=negative_prompt_embeds,
+                        txt_ids=text_ids,
+                        img_ids=latent_image_ids,
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                    )[0]
+                    noise_pred = neg_noise_pred + true_cfg * (noise_pred - neg_noise_pred)
+
                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
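The blend on the final added line is classifier-free guidance computed at inference time ("true" CFG), independent of Flux's distilled `guidance` embedding, which is still passed to both transformer calls:

$$\hat{\epsilon} = \epsilon_{\text{neg}} + s\,(\epsilon_{\text{pos}} - \epsilon_{\text{neg}}), \qquad s = \text{true\_cfg}$$

At `true_cfg = 1` this reduces to the positive prediction, which is why `do_true_cfg` gates the second transformer pass on `true_cfg > 1`.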

tests/pipelines/flux/test_pipeline_flux.py

Lines changed: 2 additions & 0 deletions
@@ -91,6 +91,8 @@ def get_dummy_components(self):
             "tokenizer_2": tokenizer_2,
             "transformer": transformer,
             "vae": vae,
+            "image_encoder": None,
+            "feature_extractor": None,
         }

     def get_dummy_inputs(self, device, seed=0):
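For context, a hedged end-to-end sketch of how the new arguments compose. The model id and scale values are illustrative, and `guidance_scale` is Flux's pre-existing distilled guidance, separate from the `true_cfg` added in this commit:

```python
import torch

from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()

image = pipe(
    prompt="a photo of a corgi wearing sunglasses",
    negative_prompt="blurry, low quality",
    true_cfg=4.0,  # > 1 triggers the second (negative) transformer pass per step
    guidance_scale=3.5,  # distilled guidance embedding, unrelated to true_cfg
    num_inference_steps=28,
).images[0]
image.save("corgi.png")
```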
