Commit 8c1492b

Apply style fixes
1 parent a3f78ac commit 8c1492b

File tree: 1 file changed, +33 -20 lines changed

src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py

Lines changed: 33 additions & 20 deletions
@@ -46,12 +46,19 @@
  >>> import torch
  >>> from diffusers.utils import load_image
  >>> from diffusers import QwenImageControlNetModel, QwenImageControlNetInpaintPipeline
+
  >>> base_model_path = "Qwen/Qwen-Image"
  >>> controlnet_model_path = "InstantX/Qwen-Image-ControlNet-Inpainting"
  >>> controlnet = QwenImageControlNetModel.from_pretrained(controlnet_model_path, torch_dtype=torch.bfloat16)
- >>> pipe = QwenImageControlNetInpaintPipeline.from_pretrained(base_model_path, controlnet=controlnet, torch_dtype=torch.bfloat16).to("cuda")
- >>> image = load_image("https://huggingface.co/InstantX/Qwen-Image-ControlNet-Inpainting/resolve/main/assets/images/image1.png")
- >>> mask_image = load_image("https://huggingface.co/InstantX/Qwen-Image-ControlNet-Inpainting/resolve/main/assets/masks/mask1.png")
+ >>> pipe = QwenImageControlNetInpaintPipeline.from_pretrained(
+ ...     base_model_path, controlnet=controlnet, torch_dtype=torch.bfloat16
+ ... ).to("cuda")
+ >>> image = load_image(
+ ...     "https://huggingface.co/InstantX/Qwen-Image-ControlNet-Inpainting/resolve/main/assets/images/image1.png"
+ ... )
+ >>> mask_image = load_image(
+ ...     "https://huggingface.co/InstantX/Qwen-Image-ControlNet-Inpainting/resolve/main/assets/masks/mask1.png"
+ ... )
  >>> prompt = "一辆绿色的出租车行驶在路上"
  >>> result = pipe(
  ...     prompt=prompt,
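Note: the docstring example is cut off after `prompt=prompt,` in this hunk. For context, a minimal sketch of how the reformatted snippet is typically completed is shown below; the remaining call arguments (`control_image`, `control_mask`, `controlnet_conditioning_scale`, `num_inference_steps`) are assumptions based on comparable ControlNet inpainting pipelines, not taken from this commit, and the Chinese prompt translates to "a green taxi driving on the road".

import torch
from diffusers.utils import load_image
from diffusers import QwenImageControlNetModel, QwenImageControlNetInpaintPipeline

# Load the base model and the inpainting ControlNet, as in the docstring above.
controlnet = QwenImageControlNetModel.from_pretrained(
    "InstantX/Qwen-Image-ControlNet-Inpainting", torch_dtype=torch.bfloat16
)
pipe = QwenImageControlNetInpaintPipeline.from_pretrained(
    "Qwen/Qwen-Image", controlnet=controlnet, torch_dtype=torch.bfloat16
).to("cuda")

image = load_image(
    "https://huggingface.co/InstantX/Qwen-Image-ControlNet-Inpainting/resolve/main/assets/images/image1.png"
)
mask_image = load_image(
    "https://huggingface.co/InstantX/Qwen-Image-ControlNet-Inpainting/resolve/main/assets/masks/mask1.png"
)

# "一辆绿色的出租车行驶在路上" = "a green taxi driving on the road"
result = pipe(
    prompt="一辆绿色的出租车行驶在路上",
    control_image=image,                # assumed parameter name
    control_mask=mask_image,            # assumed parameter name
    controlnet_conditioning_scale=1.0,  # assumed conditioning strength
    num_inference_steps=30,             # assumed step count
).images[0]
result.save("qwen_inpaint_result.png")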
@@ -80,6 +87,7 @@ def calculate_shift(
      mu = image_seq_len * m + b
      return mu

+
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
  def retrieve_latents(
      encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
@@ -93,6 +101,7 @@ def retrieve_latents(
      else:
          raise AttributeError("Could not access latents of provided encoder_output")

+
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
  def retrieve_timesteps(
      scheduler,
@@ -105,6 +114,7 @@ def retrieve_timesteps(
      r"""
      Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
      custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
      Args:
          scheduler (`SchedulerMixin`):
              The scheduler to get timesteps from.
@@ -154,6 +164,7 @@ def retrieve_timesteps(
  class QwenImageControlNetInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
      r"""
      The QwenImage pipeline for text-to-image generation.
+
      Args:
          transformer ([`QwenImageTransformer2DModel`]):
              Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
@@ -472,7 +483,7 @@ def prepare_image(
          image = torch.cat([image] * 2)

      return image
-
+
  # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet_inpainting.StableDiffusion3ControlNetPipeline.prepare_image_with_mask
  def prepare_image_with_mask(
      self,
@@ -501,45 +512,47 @@ def prepare_image_with_mask(
  repeat_by = num_images_per_prompt

  image = image.repeat_interleave(repeat_by, dim=0)
- image = image.to(device=device, dtype=dtype) # (bsz, 3, height_ori, width_ori)
+ image = image.to(device=device, dtype=dtype)  # (bsz, 3, height_ori, width_ori)

  # Prepare mask
  if isinstance(mask, torch.Tensor):
      pass
  else:
      mask = self.mask_processor.preprocess(mask, height=height, width=width)
  mask = mask.repeat_interleave(repeat_by, dim=0)
- mask = mask.to(device=device, dtype=dtype) # (bsz, 1, height_ori, width_ori)
+ mask = mask.to(device=device, dtype=dtype)  # (bsz, 1, height_ori, width_ori)

  if image.ndim == 4:
      image = image.unsqueeze(2)
-
+
  if mask.ndim == 4:
      mask = mask.unsqueeze(2)

  # Get masked image
  masked_image = image.clone()
- masked_image[(mask > 0.5).repeat(1, 3, 1, 1, 1)] = -1 # (bsz, 3, 1, height_ori, width_ori)
-
+ masked_image[(mask > 0.5).repeat(1, 3, 1, 1, 1)] = -1  # (bsz, 3, 1, height_ori, width_ori)
+
  self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample)
  latents_mean = (torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1)).to(device)
- latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(device)
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+     device
+ )

  # Encode to latents
  image_latents = self.vae.encode(masked_image.to(self.vae.dtype)).latent_dist.sample()
- image_latents = (
-     image_latents - latents_mean
- ) * latents_std
- image_latents = image_latents.to(dtype) # torch.Size([1, 16, 1, height_ori//8, width_ori//8])
+ image_latents = (image_latents - latents_mean) * latents_std
+ image_latents = image_latents.to(dtype)  # torch.Size([1, 16, 1, height_ori//8, width_ori//8])

  mask = torch.nn.functional.interpolate(
      mask, size=(image_latents.shape[-3], image_latents.shape[-2], image_latents.shape[-1])
  )
- mask = 1 - mask # torch.Size([1, 1, 1, height_ori//8, width_ori//8])
+ mask = 1 - mask  # torch.Size([1, 1, 1, height_ori//8, width_ori//8])

- control_image = torch.cat([image_latents, mask], dim=1) # torch.Size([1, 16+1, 1, height_ori//8, width_ori//8])
+ control_image = torch.cat(
+     [image_latents, mask], dim=1
+ )  # torch.Size([1, 16+1, 1, height_ori//8, width_ori//8])

- control_image = control_image.permute(0, 2, 1, 3, 4) # torch.Size([1, 1, 16+1, height_ori//8, width_ori//8])
+ control_image = control_image.permute(0, 2, 1, 3, 4)  # torch.Size([1, 1, 16+1, height_ori//8, width_ori//8])

  # pack
  control_image = self._pack_latents(
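Note: the reformatted block above builds the ControlNet conditioning tensor: the masked region is blanked to -1, the masked image is VAE-encoded and normalized, the mask is downsampled to latent resolution and inverted, and the result is concatenated into a 16+1-channel control image. Below is a standalone sketch of the same tensor bookkeeping, with random tensors standing in for the VAE encoder so the shapes are easy to follow; all concrete sizes are illustrative assumptions, not values from this commit.

import torch

bsz, h, w = 1, 1024, 1024
image = torch.rand(bsz, 3, 1, h, w) * 2 - 1          # pixel image in [-1, 1], with an extra frame dim
mask = (torch.rand(bsz, 1, 1, h, w) > 0.5).float()   # 1 = region to inpaint

# Blank out the region to be repainted (-1 is "black" in [-1, 1] space)
masked_image = image.clone()
masked_image[(mask > 0.5).repeat(1, 3, 1, 1, 1)] = -1

# Stand-in for VAE encoding followed by (latents - mean) * inv_std normalization
image_latents = torch.randn(bsz, 16, 1, h // 8, w // 8)
latents_mean = torch.zeros(1, 16, 1, 1, 1)
latents_inv_std = torch.ones(1, 16, 1, 1, 1)
image_latents = (image_latents - latents_mean) * latents_inv_std

# Downsample the mask to latent resolution and invert it (1 = keep)
mask = torch.nn.functional.interpolate(mask, size=(1, h // 8, w // 8))
mask = 1 - mask

# 16 latent channels + 1 mask channel, then move channels after the frame dim
control_image = torch.cat([image_latents, mask], dim=1)  # (bsz, 17, 1, h//8, w//8)
control_image = control_image.permute(0, 2, 1, 3, 4)     # (bsz, 1, 17, h//8, w//8)
print(control_image.shape)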
@@ -608,6 +621,7 @@ def __call__(
  ):
      r"""
      Function invoked when calling the pipeline for generation.
+
      Args:
          prompt (`str` or `List[str]`, *optional*):
              The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
@@ -670,8 +684,7 @@ def __call__(
          will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
          `._callback_tensor_inputs` attribute of your pipeline class.
      max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
- Examples:
- Returns:
+ Examples: Returns:
      [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
      [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
      returning a tuple, the first element is a list with the generated images.
@@ -839,7 +852,7 @@ def __call__(
      txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
      return_dict=False,
  )
-
+
  with self.transformer.cache_context("cond"):
      noise_pred = self.transformer(
          hidden_states=latents,
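Note: the `cache_context("cond")` block shown above is the conditional transformer pass; in comparable Qwen-Image pipelines it is followed by an unconditional pass and a true classifier-free-guidance combination. Below is a hedged, self-contained sketch of that combination step, with a stub standing in for `self.transformer`; the stub, its arguments, and the shapes are illustrative assumptions, not the pipeline's API.

import torch

def stub_transformer(hidden_states, guidance_shift):
    # Stand-in for self.transformer(...): returns a fake noise prediction.
    return hidden_states * 0.1 + guidance_shift

latents = torch.randn(1, 4096, 64)  # packed latents; shape is illustrative
true_cfg_scale = 4.0

noise_pred = stub_transformer(latents, guidance_shift=0.2)       # "cond" pass
neg_noise_pred = stub_transformer(latents, guidance_shift=-0.2)  # "uncond" pass

# Classifier-free guidance: push the prediction away from the unconditional one.
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
print(noise_pred.shape)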
