Skip to content

Commit 33277c2

Browse files
committed
format
1 parent 66f835b commit 33277c2

File tree

2 files changed: +10 −12 lines changed

src/diffusers/models/controlnets/controlnet_qwenimage.py

Lines changed: 1 addition & 1 deletion
Original line | Diff line | Change
@@ -356,4 +356,4 @@ def forward(
356  356            else:
357  357                raise ValueError("QwenImageMultiControlNetModel only supports controlnet-union now.")
358  358
359       -        return control_block_samples
     359  +        return control_block_samples

src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py

Lines changed: 9 additions & 11 deletions
Original line | Diff line | Change
@@ -191,9 +191,7 @@ def __init__(
191  191        text_encoder: Qwen2_5_VLForConditionalGeneration,
192  192        tokenizer: Qwen2Tokenizer,
193  193        transformer: QwenImageTransformer2DModel,
194       -    controlnet: Union[
195       -        QwenImageControlNetModel, QwenImageMultiControlNetModel
196       -    ],
     194  +    controlnet: Union[QwenImageControlNetModel, QwenImageMultiControlNetModel],
197  195    ):
198  196        super().__init__()
199  197
@@ -701,7 +699,7 @@ def __call__(
701  699                    height=control_image.shape[3],
702  700                    width=control_image.shape[4],
703  701                ).to(dtype=prompt_embeds.dtype, device=device)
704       -
     702  +
705  703            else:
706  704                if isinstance(self.controlnet, QwenImageMultiControlNetModel):
707  705                    control_images = []
@@ -723,12 +721,12 @@ def __call__(
723  721
724  722                    # vae encode
725  723                    self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample)
726       -                latents_mean = (torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1)).to(
727       -                    device
728       -                )
729       -                latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
730       -                    device
731       -                )
     724  +                latents_mean = (
     725  +                    torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1)
     726  +                ).to(device)
     727  +                latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(
     728  +                    1, self.vae.config.z_dim, 1, 1, 1
     729  +                ).to(device)
732  730
733  731                    control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)
734  732                    control_image_ = (control_image_ - latents_mean) * latents_std
@@ -818,7 +816,7 @@ def __call__(
818  816                    if isinstance(controlnet_cond_scale, list):
819  817                        controlnet_cond_scale = controlnet_cond_scale[0]
820  818                    cond_scale = controlnet_cond_scale * controlnet_keep[i]
821       -
     819  +
822  820                    # controlnet
823  821                    controlnet_block_samples = self.controlnet(
824  822                        hidden_states=latents,

Comments (0)