Skip to content

Commit 4a8ccd7

Browse files
committed
stay up to date with main branch
1 parent 3097624 commit 4a8ccd7

File tree

1 file changed

+77
-42
lines changed

1 file changed

+77
-42
lines changed

src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py

Lines changed: 77 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -421,26 +421,27 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
421421

422422
return latents
423423

424+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_img2img.QwenImageImg2ImgPipeline._encode_vae_image
424425
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
425426
if isinstance(generator, list):
426427
image_latents = [
427-
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax")
428+
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
428429
for i in range(image.shape[0])
429430
]
430431
image_latents = torch.cat(image_latents, dim=0)
431432
else:
432-
image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
433+
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
434+
433435
latents_mean = (
434436
torch.tensor(self.vae.config.latents_mean)
435-
.view(1, self.latent_channels, 1, 1, 1)
437+
.view(1, self.vae.config.z_dim, 1, 1, 1)
436438
.to(image_latents.device, image_latents.dtype)
437439
)
438-
latents_std = (
439-
torch.tensor(self.vae.config.latents_std)
440-
.view(1, self.latent_channels, 1, 1, 1)
441-
.to(image_latents.device, image_latents.dtype)
440+
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
441+
image_latents.device, image_latents.dtype
442442
)
443-
image_latents = (image_latents - latents_mean) / latents_std
443+
444+
image_latents = (image_latents - latents_mean) * latents_std
444445

445446
return image_latents
446447

@@ -485,6 +486,7 @@ def disable_vae_tiling(self):
485486
"""
486487
self.vae.disable_tiling()
487488

489+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_inpaint.QwenImageInpaintPipeline.prepare_latents
488490
def prepare_latents(
489491
self,
490492
image,
@@ -510,25 +512,32 @@ def prepare_latents(
510512

511513
shape = (batch_size, 1, num_channels_latents, height, width)
512514

513-
image_latents = None
514-
if image is not None:
515-
image = image.to(device=device, dtype=dtype)
516-
if image.shape[1] != self.latent_channels:
517-
image_latents = self._encode_vae_image(image=image, generator=generator)
518-
else:
519-
image_latents = image
520-
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
521-
# expand init_latents for batch_size
522-
additional_image_per_prompt = batch_size // image_latents.shape[0]
523-
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
524-
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
525-
raise ValueError(
526-
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
527-
)
528-
else:
529-
image_latents = torch.cat([image_latents], dim=0)
515+
# If image is [B,C,H,W] -> add T=1. If it's already [B,C,T,H,W], leave it.
516+
if image.dim() == 4:
517+
image = image.unsqueeze(2)
518+
elif image.dim() != 5:
519+
raise ValueError(f"Expected image dims 4 or 5, got {image.dim()}.")
520+
521+
if latents is not None:
522+
return latents.to(device=device, dtype=dtype)
523+
524+
image = image.to(device=device, dtype=dtype)
525+
if image.shape[1] != self.latent_channels:
526+
image_latents = self._encode_vae_image(image=image, generator=generator) # [B,z,1,H',W']
527+
else:
528+
image_latents = image
529+
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
530+
# expand init_latents for batch_size
531+
additional_image_per_prompt = batch_size // image_latents.shape[0]
532+
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
533+
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
534+
raise ValueError(
535+
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
536+
)
537+
else:
538+
image_latents = torch.cat([image_latents], dim=0)
530539

531-
image_latents = image_latents.transpose(1, 2) # [B,1,z,H',W']
540+
image_latents = image_latents.transpose(1, 2) # [B,1,z,H',W']
532541

533542
if latents is None:
534543
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
@@ -655,7 +664,7 @@ def __call__(
655664
strength: float = 0.6,
656665
num_inference_steps: int = 50,
657666
sigmas: Optional[List[float]] = None,
658-
guidance_scale: float = 1.0,
667+
guidance_scale: Optional[float] = None,
659668
num_images_per_prompt: int = 1,
660669
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
661670
latents: Optional[torch.Tensor] = None,
@@ -674,6 +683,12 @@ def __call__(
674683
Function invoked when calling the pipeline for generation.
675684
676685
Args:
686+
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
687+
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
688+
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
689+
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
690+
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
691+
latents as `image`, but if passing latents directly it is not encoded again.
677692
prompt (`str` or `List[str]`, *optional*):
678693
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
679694
instead.
@@ -682,7 +697,12 @@ def __call__(
682697
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
683698
not greater than `1`).
684699
true_cfg_scale (`float`, *optional*, defaults to 1.0):
685-
When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
700+
true_cfg_scale (`float`, *optional*, defaults to 1.0): Guidance scale as defined in [Classifier-Free
701+
Diffusion Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of
702+
equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is
703+
enabled by setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale
704+
encourages to generate images that are closely linked to the text `prompt`, usually at the expense of
705+
lower image quality.
686706
mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
687707
`Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
688708
are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
@@ -717,17 +737,16 @@ def __call__(
717737
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
718738
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
719739
will be used.
720-
guidance_scale (`float`, *optional*, defaults to 3.5):
721-
Guidance scale as defined in [Classifier-Free Diffusion
722-
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
723-
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
724-
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
725-
the text `prompt`, usually at the expense of lower image quality.
726-
727-
This parameter in the pipeline is there to support future guidance-distilled models when they come up.
728-
Note that passing `guidance_scale` to the pipeline is ineffective. To enable classifier-free guidance,
729-
please pass `true_cfg_scale` and `negative_prompt` (even an empty negative prompt like " ") should
730-
enable classifier-free guidance computations.
740+
guidance_scale (`float`, *optional*, defaults to None):
741+
A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
742+
where the guidance scale is applied during inference through noise prediction rescaling, guidance
743+
distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
744+
scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images
745+
that are closely linked to the text `prompt`, usually at the expense of lower image quality. This
746+
parameter in the pipeline is there to support future guidance-distilled models when they come up. It is
747+
ignored when not using guidance distilled models. To enable traditional classifier-free guidance,
748+
please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should
749+
enable classifier-free guidance computations).
731750
num_images_per_prompt (`int`, *optional*, defaults to 1):
732751
The number of images to generate per prompt.
733752
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -831,11 +850,20 @@ def __call__(
831850
image, height=calculated_height, width=calculated_width, crops_coords=crops_coords, resize_mode=resize_mode
832851
)
833852
image = image.to(dtype=torch.float32)
834-
image = image.unsqueeze(2)
835853

836854
has_neg_prompt = negative_prompt is not None or (
837855
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
838856
)
857+
858+
if true_cfg_scale > 1 and not has_neg_prompt:
859+
logger.warning(
860+
f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided."
861+
)
862+
elif true_cfg_scale <= 1 and has_neg_prompt:
863+
logger.warning(
864+
" negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1"
865+
)
866+
839867
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
840868
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
841869
image=prompt_image,
@@ -932,10 +960,17 @@ def __call__(
932960
self._num_timesteps = len(timesteps)
933961

934962
# handle guidance
935-
if self.transformer.config.guidance_embeds:
963+
if self.transformer.config.guidance_embeds and guidance_scale is None:
964+
raise ValueError("guidance_scale is required for guidance-distilled model.")
965+
elif self.transformer.config.guidance_embeds:
936966
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
937967
guidance = guidance.expand(latents.shape[0])
938-
else:
968+
elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
969+
logger.warning(
970+
f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
971+
)
972+
guidance = None
973+
elif not self.transformer.config.guidance_embeds and guidance_scale is None:
939974
guidance = None
940975

941976
if self.attention_kwargs is None:

0 commit comments

Comments
 (0)