
Commit 43d041a

committed
update
1 parent 03165b9 commit 43d041a

File tree

1 file changed (+134, -61 lines)


src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py

Lines changed: 134 additions & 61 deletions
@@ -292,8 +292,10 @@ def encode_prompt(
         negative_prompt: Union[str, List[str]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        prompt_attention_mask: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         do_classifier_free_guidance: bool = True,
         max_sequence_length: int = 512,
         lora_scale: Optional[float] = None,
@@ -310,7 +312,7 @@ def encode_prompt(
                 torch device
             num_images_per_prompt (`int`):
                 number of images that should be generated per prompt
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
             lora_scale (`float`, *optional*):
@@ -335,7 +337,7 @@ def encode_prompt(
             batch_size = prompt_embeds.shape[0]
 
         if prompt_embeds is None:
-            prompt_embeds = self._get_t5_prompt_embeds(
+            prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
                 prompt=prompt,
                 num_images_per_prompt=num_images_per_prompt,
                 max_sequence_length=max_sequence_length,
@@ -365,20 +367,28 @@ def encode_prompt(
                     " the batch size of `prompt`."
                 )
 
-            negative_prompt_embeds = self._get_t5_prompt_embeds(
+            negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
                 prompt=negative_prompt,
                 num_images_per_prompt=num_images_per_prompt,
                 max_sequence_length=max_sequence_length,
                 device=device,
             )
+
             negative_text_ids = torch.zeros(negative_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
 
         if self.text_encoder is not None:
             if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                 # Retrieve the original scale by scaling back the LoRA layers
                 unscale_lora_layers(self.text_encoder, lora_scale)
 
-        return prompt_embeds, text_ids, negative_prompt_embeds, negative_text_ids
+        return (
+            prompt_embeds,
+            text_ids,
+            prompt_attention_mask,
+            negative_prompt_embeds,
+            negative_text_ids,
+            negative_prompt_attention_mask,
+        )
 
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
     def encode_image(self, image, device, num_images_per_prompt):
@@ -392,52 +402,44 @@ def encode_image(self, image, device, num_images_per_prompt):
         image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
         return image_embeds
 
-    # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.prepare_ip_adapter_image_embeds
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
     def prepare_ip_adapter_image_embeds(
-        self,
-        ip_adapter_image: Optional[PipelineImageInput] = None,
-        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
-        device: Optional[torch.device] = None,
-        num_images_per_prompt: int = 1,
-        do_classifier_free_guidance: bool = True,
-    ) -> torch.Tensor:
-        """Prepares image embeddings for use in the IP-Adapter.
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+    ):
+        device = device or self._execution_device
 
-        Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
+        image_embeds = []
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
 
-        Args:
-            ip_adapter_image (`PipelineImageInput`, *optional*):
-                The input image to extract features from for IP-Adapter.
-            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
-                Precomputed image embeddings.
-            device: (`torch.device`, *optional*):
-                Torch device.
-            num_images_per_prompt (`int`, defaults to 1):
-                Number of images that should be generated per prompt.
-            do_classifier_free_guidance (`bool`, defaults to True):
-                Whether to use classifier free guidance or not.
-        """
-        device = device or self._execution_device
+            if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+                )
 
-        if ip_adapter_image_embeds is not None:
-            if do_classifier_free_guidance:
-                single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
-            else:
-                single_image_embeds = ip_adapter_image_embeds
-        elif ip_adapter_image is not None:
-            single_image_embeds = self.encode_image(ip_adapter_image, device)
-            if do_classifier_free_guidance:
-                single_negative_image_embeds = torch.zeros_like(single_image_embeds)
+            for single_ip_adapter_image in ip_adapter_image:
+                single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
+                image_embeds.append(single_image_embeds[None, :])
         else:
-            raise ValueError("Neither `ip_adapter_image_embeds` or `ip_adapter_image_embeds` were provided.")
+            if not isinstance(ip_adapter_image_embeds, list):
+                ip_adapter_image_embeds = [ip_adapter_image_embeds]
 
-        image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+            if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+                )
 
-        if do_classifier_free_guidance:
-            negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
-            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
+            for single_image_embeds in ip_adapter_image_embeds:
+                image_embeds.append(single_image_embeds)
+
+        ip_adapter_image_embeds = []
+        for single_image_embeds in image_embeds:
+            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+            single_image_embeds = single_image_embeds.to(device=device)
+            ip_adapter_image_embeds.append(single_image_embeds)
 
-        return image_embeds.to(device=device)
+        return ip_adapter_image_embeds
 
     def check_inputs(
         self,
@@ -448,6 +450,8 @@ def check_inputs(
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        prompt_attention_mask=None,
+        negative_prompt_attention_mask=None,
         callback_on_step_end_tensor_inputs=None,
         max_sequence_length=None,
     ):
@@ -483,6 +487,15 @@ def check_inputs(
                 f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                 f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
             )
+        if prompt_attention_mask is not None and negative_prompt_attention_mask is None:
+            raise ValueError(
+                "Cannot provide `prompt_attention_mask` without also providing `negative_prompt_attention_mask`"
+            )
+
+        if negative_prompt_attention_mask is not None and prompt_attention_mask is None:
+            raise ValueError(
+                "Cannot provide `negative_prompt_attention_mask` without also providing `prompt_attention_mask`"
+            )
 
         if max_sequence_length is not None and max_sequence_length > 512:
             raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
@@ -591,7 +604,7 @@ def prepare_latents(
         height = 2 * (int(height) // (self.vae_scale_factor * 2))
         width = 2 * (int(width) // (self.vae_scale_factor * 2))
         shape = (batch_size, num_channels_latents, height, width)
-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+        latent_image_ids = self._prepare_latent_image_ids(height // 2, width // 2, device, dtype)
 
         if latents is not None:
             return latents.to(device=device, dtype=dtype), latent_image_ids
@@ -617,6 +630,25 @@ def prepare_latents(
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
         return latents, latent_image_ids
 
+    def _prepare_attention_mask(
+        self,
+        batch_size,
+        sequence_length,
+        dtype,
+        attention_mask=None,
+    ):
+        if attention_mask is None:
+            return attention_mask
+
+        # Extend the prompt attention mask to account for image tokens in the final sequence
+        attention_mask = torch.cat(
+            [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device)],
+            dim=1,
+        )
+        attention_mask = attention_mask.to(dtype)
+
+        return attention_mask
+
     @property
     def guidance_scale(self):
         return self._guidance_scale
@@ -656,13 +688,15 @@ def __call__(
         strength: float = 0.8,
         num_images_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
         ip_adapter_image: Optional[PipelineImageInput] = None,
         ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
         negative_ip_adapter_image: Optional[PipelineImageInput] = None,
         negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        prompt_attention_mask: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask: Optional[torch.tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -703,11 +737,11 @@ def __call__(
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                 to make generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
+            latents (`torch.Tensor`, *optional*):
                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor will ge generated by sampling using the supplied random `generator`.
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
             ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
@@ -721,7 +755,7 @@ def __call__(
                 Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
                 IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
                 provided, embeddings are computed from the `ip_adapter_image` input argument.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
@@ -765,6 +799,8 @@ def __call__(
             negative_prompt=negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
+            prompt_attention_mask=prompt_attention_mask,
+            negative_prompt_attention_mask=negative_prompt_attention_mask,
             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
             max_sequence_length=max_sequence_length,
         )
@@ -794,13 +830,17 @@ def __call__(
         (
             prompt_embeds,
             text_ids,
+            prompt_attention_mask,
             negative_prompt_embeds,
             negative_text_ids,
+            negative_prompt_attention_mask,
         ) = self.encode_prompt(
             prompt=prompt,
             negative_prompt=negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
+            prompt_attention_mask=prompt_attention_mask,
+            negative_prompt_attention_mask=negative_prompt_attention_mask,
             do_classifier_free_guidance=self.do_classifier_free_guidance,
             device=device,
             num_images_per_prompt=num_images_per_prompt,
@@ -856,20 +896,55 @@ def __call__(
             latents,
         )
 
+        attention_mask = self._prepare_attention_mask(
+            batch_size=latents.shape[0],
+            sequence_length=image_seq_len,
+            dtype=latents.dtype,
+            attention_mask=prompt_attention_mask,
+        )
+        if self.do_classifier_free_guidance and negative_prompt_attention_mask is not None:
+            negative_attention_mask = self._prepare_attention_mask(
+                batch_size=latents.shape[0],
+                sequence_length=image_seq_len,
+                dtype=latents.dtype,
+                attention_mask=negative_prompt_attention_mask,
+            )
+            attention_mask = torch.cat([negative_attention_mask, attention_mask], dim=0)
+
         # 6. Prepare image embeddings
-        if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
-            ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
+        if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
+            negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
+        ):
+            negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+            negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
+
+        elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
+            negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
+        ):
+            ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+            ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
+
+        image_embeds = None
+        negative_image_embeds = None
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+            image_embeds = self.prepare_ip_adapter_image_embeds(
                 ip_adapter_image,
                 ip_adapter_image_embeds,
                 device,
                 batch_size * num_images_per_prompt,
-                self.do_classifier_free_guidance,
             )
+        if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
+            negative_image_embeds = self.prepare_ip_adapter_image_embeds(
+                negative_ip_adapter_image,
+                negative_ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+            )
+        if self.do_classifier_free_guidance and image_embeds is not None and negative_image_embeds is not None:
+            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
 
-        if self.joint_attention_kwargs is None:
-            self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
-        else:
-            self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
+        if image_embeds is not None:
+            self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
 
         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -878,9 +953,6 @@ def __call__(
                     continue
 
                 self._current_timestep = t
-                if ip_adapter_image_embeds is not None:
-                    self._joint_attention_kwargs["ip_adapter_image_embeds"] = ip_adapter_image_embeds
-
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
 
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
@@ -892,6 +964,7 @@ def __call__(
                     encoder_hidden_states=prompt_embeds,
                     txt_ids=text_ids,
                     img_ids=latent_image_ids,
+                    attention_mask=attention_mask,
                     joint_attention_kwargs=self.joint_attention_kwargs,
                     return_dict=False,
                 )[0]
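
Below is a minimal usage sketch (not part of the diff) of the attention-mask path introduced here, assuming a checkpoint loadable with `ChromaImg2ImgPipeline`; the checkpoint id, prompts, and placeholder input image are illustrative only.

# Sketch: precompute embeddings plus the new attention masks, then feed them back into __call__.
import torch
from PIL import Image
from diffusers import ChromaImg2ImgPipeline

# Placeholder checkpoint id; substitute any Chroma checkpoint compatible with this pipeline.
pipe = ChromaImg2ImgPipeline.from_pretrained("<chroma-checkpoint>", torch_dtype=torch.bfloat16).to("cuda")
init_image = Image.new("RGB", (1024, 1024), color=(127, 127, 127))  # placeholder input image

# encode_prompt() now returns a 6-tuple that includes the prompt and negative attention masks.
(
    prompt_embeds,
    text_ids,
    prompt_attention_mask,
    negative_prompt_embeds,
    negative_text_ids,
    negative_prompt_attention_mask,
) = pipe.encode_prompt(
    prompt="a watercolor lighthouse at dusk",
    negative_prompt="low quality, blurry",
    do_classifier_free_guidance=True,
)

# check_inputs() requires the two masks to be passed together when embeddings are precomputed.
image = pipe(
    image=init_image,
    strength=0.8,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    prompt_attention_mask=prompt_attention_mask,
    negative_prompt_attention_mask=negative_prompt_attention_mask,
).images[0]
image.save("chroma_img2img.png")

If only plain `prompt`/`negative_prompt` strings are passed, the masks are produced internally by `_get_t5_prompt_embeds`, extended over the image tokens by `_prepare_attention_mask`, and forwarded to the transformer via `attention_mask`.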
