
Commit a562806

Commit message: style
1 parent 4483400 commit a562806


8 files changed, +175 −137 lines changed

src/diffusers/image_processor.py

Lines changed: 21 additions & 27 deletions
@@ -842,11 +842,12 @@ class InpaintProcessor(ConfigMixin):
     """
     Image processor for inpainting image and mask.
     """
+
     config_name = CONFIG_NAME

     @register_to_config
     def __init__(
-        self,
+        self,
         do_resize: bool = True,
         vae_scale_factor: int = 8,
         vae_latent_channels: int = 4,
@@ -855,42 +856,40 @@ def __init__(
         do_normalize: bool = True,
         do_binarize: bool = False,
         do_convert_grayscale: bool = False,
-        mask_do_normalize: bool = False,
-        mask_do_binarize: bool = True,
+        mask_do_normalize: bool = False,
+        mask_do_binarize: bool = True,
         mask_do_convert_grayscale: bool = True,
-    ):
-
+    ):
         super().__init__()

         self._image_processor = VaeImageProcessor(
             do_resize=do_resize,
-            vae_scale_factor=vae_scale_factor,
+            vae_scale_factor=vae_scale_factor,
             vae_latent_channels=vae_latent_channels,
             resample=resample,
             reducing_gap=reducing_gap,
             do_normalize=do_normalize,
             do_binarize=do_binarize,
             do_convert_grayscale=do_convert_grayscale,
-        )
+        )
         self._mask_processor = VaeImageProcessor(
             do_resize=do_resize,
-            vae_scale_factor=vae_scale_factor,
+            vae_scale_factor=vae_scale_factor,
             vae_latent_channels=vae_latent_channels,
             resample=resample,
             reducing_gap=reducing_gap,
-            do_normalize=mask_do_normalize,
-            do_binarize=mask_do_binarize,
-            do_convert_grayscale=mask_do_convert_grayscale,
-        )
+            do_normalize=mask_do_normalize,
+            do_binarize=mask_do_binarize,
+            do_convert_grayscale=mask_do_convert_grayscale,
+        )

-
     def preprocess(
         self,
         image: PIL.Image.Image,
         mask: PIL.Image.Image = None,
-        height:int = None,
-        width:int = None,
-        padding_mask_crop:Optional[int] = None,
+        height: int = None,
+        width: int = None,
+        padding_mask_crop: Optional[int] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Preprocess the image and mask.
@@ -903,14 +902,12 @@ def preprocess(
             return self._image_processor.preprocess(image, height=height, width=width)

         if padding_mask_crop is not None:
-            crops_coords = self._image_processor.get_crop_region(
-                mask, width, height, pad=padding_mask_crop
-            )
+            crops_coords = self._image_processor.get_crop_region(mask, width, height, pad=padding_mask_crop)
             resize_mode = "fill"
         else:
             crops_coords = None
             resize_mode = "default"
-
+
         processed_image = self._image_processor.preprocess(
             image,
             height=height,
@@ -919,7 +916,6 @@ def preprocess(
             resize_mode=resize_mode,
         )

-
         processed_mask = self._mask_processor.preprocess(
             mask,
             height=height,
@@ -928,7 +924,6 @@ def preprocess(
             crops_coords=crops_coords,
         )

-
         if crops_coords is not None:
             postprocessing_kwargs = {
                 "crops_coords": crops_coords,
@@ -944,7 +939,6 @@ def preprocess(

         return processed_image, processed_mask, postprocessing_kwargs

-
     def postprocess(
         self,
         image: torch.Tensor,
@@ -965,10 +959,10 @@ def postprocess(
             raise ValueError("original_image and original_mask must be provided if crops_coords is provided")

         elif crops_coords is not None:
-            image = [self._image_processor.apply_overlay(
-                original_mask, original_image, i, crops_coords
-            ) for i in image]
-
+            image = [
+                self._image_processor.apply_overlay(original_mask, original_image, i, crops_coords) for i in image
+            ]
+
         return image
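
For orientation, the refactored InpaintProcessor wraps one VaeImageProcessor for the image and one for the mask, and preprocess hands back the kwargs that postprocess needs to paste an inpainted crop back when padding_mask_crop is set. A minimal usage sketch, not part of the commit; the import path, file names, and the final **postprocessing_kwargs call are assumptions inferred from the hunks above:

import PIL.Image
from diffusers.image_processor import InpaintProcessor  # path assumed from the file location

processor = InpaintProcessor(vae_scale_factor=8)

image = PIL.Image.open("photo.png").convert("RGB")  # illustrative file names
mask = PIL.Image.open("mask.png").convert("L")

# Returns tensors plus the kwargs postprocess needs to overlay the inpainted
# region onto the original image when padding_mask_crop is used.
processed_image, processed_mask, postprocessing_kwargs = processor.preprocess(
    image, mask, height=1024, width=1024, padding_mask_crop=32
)

# ... run the denoising pipeline on processed_image / processed_mask ...
# Assumed round-trip: the kwargs carry crops_coords plus the original image and mask.
# images = processor.postprocess(decoded_tensor, **postprocessing_kwargs)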

src/diffusers/modular_pipelines/qwenimage/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -25,8 +25,8 @@
     _import_structure["modular_blocks"] = [
         "ALL_BLOCKS",
         "CONTROLNET_BLOCKS",
-        "TEXT2IMAGE_BLOCKS",
         "INPAINT_BLOCKS",
+        "TEXT2IMAGE_BLOCKS",
     ]
     _import_structure["modular_pipeline"] = ["QwenImageModularPipeline"]

@@ -43,8 +43,8 @@
     from .modular_blocks import (
         ALL_BLOCKS,
         CONTROLNET_BLOCKS,
-        TEXT2IMAGE_BLOCKS,
         INPAINT_BLOCKS,
+        TEXT2IMAGE_BLOCKS,
     )
     from .modular_pipeline import QwenImageModularPipeline
 else:
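
These names feed diffusers' lazy-import machinery, so the change above is purely a re-sort. A rough sketch of the pattern, simplified and not the literal contents of this __init__.py:

# Simplified sketch of the lazy-import pattern used by diffusers __init__ files
# (meant to live inside a package __init__.py, not run as a standalone script).
# _import_structure maps submodule -> exported names; attributes resolve on first access.
import sys
from typing import TYPE_CHECKING

_import_structure = {
    "modular_blocks": ["ALL_BLOCKS", "CONTROLNET_BLOCKS", "INPAINT_BLOCKS", "TEXT2IMAGE_BLOCKS"],
    "modular_pipeline": ["QwenImageModularPipeline"],
}

if TYPE_CHECKING:
    # eager imports so type checkers see the real symbols
    from .modular_blocks import ALL_BLOCKS, CONTROLNET_BLOCKS, INPAINT_BLOCKS, TEXT2IMAGE_BLOCKS
    from .modular_pipeline import QwenImageModularPipeline
else:
    from diffusers.utils import _LazyModule  # helper that defers the actual imports

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)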

src/diffusers/modular_pipelines/qwenimage/before_denoise.py

Lines changed: 69 additions & 40 deletions
@@ -18,10 +18,7 @@
 import numpy as np
 import torch

-from ...configuration_utils import FrozenDict
-from ...image_processor import VaeImageProcessor
 from ...models import QwenImageControlNetModel, QwenImageMultiControlNetModel
-from ...pipelines.qwenimage.pipeline_qwenimage_edit import calculate_dimensions
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils.torch_utils import randn_tensor, unwrap_module
 from ..modular_pipeline import ModularPipelineBlocks, PipelineState
@@ -243,45 +240,62 @@ def expected_components(self) -> List[ComponentSpec]:
         return [
             ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
         ]
-
+
     @property
     def inputs(self) -> List[InputParam]:
         return [
-            InputParam(name="latents", required=True, type_hint=torch.Tensor, description="The initial random noised, can be generated in prepare latent step."),
-            InputParam(name="image_latents", required=True, type_hint=torch.Tensor, description="The image latents to use for the denoising process. Can be generated in vae encoder + pack latents step."),
-            InputParam(name="timesteps", required=True, type_hint=torch.Tensor, description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",),
-            InputParam(name="batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in expand textinput step."),
+            InputParam(
+                name="latents",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The initial random noised, can be generated in prepare latent step.",
+            ),
+            InputParam(
+                name="image_latents",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The image latents to use for the denoising process. Can be generated in vae encoder + pack latents step.",
+            ),
+            InputParam(
+                name="timesteps",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
+            ),
+            InputParam(
+                name="batch_size",
+                required=True,
+                type_hint=int,
+                description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in expand textinput step.",
+            ),
             InputParam(name="num_images_per_prompt", required=True),
         ]

     @property
     def intermediate_outputs(self) -> List[OutputParam]:
         return [
-            OutputParam(name="initial_noise", type_hint=torch.Tensor, description="The initial random noised used for inpainting denoising."),
+            OutputParam(
+                name="initial_noise",
+                type_hint=torch.Tensor,
+                description="The initial random noised used for inpainting denoising.",
+            ),
         ]
-
-
+
     @staticmethod
     def check_inputs(image_latents, latents, batch_size):
-
         if image_latents.shape[0] != batch_size:
             raise ValueError(
                 f"`image_latents` must have have batch size {batch_size}, but got {image_latents.shape[0]}"
             )

         if image_latents.ndim != 3:
             raise ValueError(f"`image_latents` must have 3 dimensions (patchified), but got {image_latents.ndim}")
-
-
+
         if latents.shape[0] != batch_size:
-            raise ValueError(
-                f"`latents` must have have batch size {batch_size}, but got {latents.shape[0]}"
-            )
-
-
+            raise ValueError(f"`latents` must have have batch size {batch_size}, but got {latents.shape[0]}")
+
     @torch.no_grad()
     def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
-
         block_state = self.get_block_state(state)
         final_batch_size = block_state.batch_size * block_state.num_images_per_prompt

@@ -290,43 +304,52 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
             latents=block_state.latents,
             batch_size=final_batch_size,
         )
-
+
         # prepare latent timestep
         latent_timestep = block_state.timesteps[:1].repeat(final_batch_size)
-
+
         # make copy of initial_noise
         block_state.initial_noise = block_state.latents

         # scale noise
-        block_state.latents = components.scheduler.scale_noise(block_state.image_latents, latent_timestep, block_state.latents)
+        block_state.latents = components.scheduler.scale_noise(
+            block_state.image_latents, latent_timestep, block_state.latents
+        )

         self.set_block_state(state, block_state)
-
-        return components, state
+
+        return components, state


 class QwenImageCreateMaskLatentsStep(ModularPipelineBlocks):
     model_name = "qwenimage"
-
+
     @property
     def description(self) -> str:
         return "Step that create the mask latents for the inpainting process. Should be run with the pachify latents step."
-
+
     @property
     def inputs(self) -> List[InputParam]:
         return [
-            InputParam(name="mask_image", required=True, type_hint=torch.Tensor, description="The mask to use for the inpainting process."),
+            InputParam(
+                name="mask_image",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The mask to use for the inpainting process.",
+            ),
             InputParam(name="height", required=True),
             InputParam(name="width", required=True),
             InputParam(name="dtype", required=True),
         ]
-
+
     @property
     def intermediate_outputs(self) -> List[OutputParam]:
         return [
-            OutputParam(name="mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process."),
+            OutputParam(
+                name="mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process."
+            ),
         ]
-
+
     @torch.no_grad()
     def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
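
The scale_noise call reformatted above is what seeds the inpainting latents: the packed image latents are blended with fresh noise at the first timestep, and the untouched noise is kept as initial_noise for later masked re-noising. A standalone sketch of that blend; the linear form below is an assumption about how the flow-match scheduler mixes sample and noise, not the scheduler's actual code:

import torch

def blend_latents(image_latents: torch.Tensor, noise: torch.Tensor, sigma: float) -> torch.Tensor:
    # sigma near 1.0 -> mostly noise (strong edit); sigma near 0.0 -> mostly the source image
    return sigma * noise + (1.0 - sigma) * image_latents

# packed (batch, seq_len, channels) latents; shapes are illustrative
image_latents = torch.randn(1, 4096, 64)
noise = torch.randn_like(image_latents)

initial_noise = noise                                      # kept so masked regions can be re-noised per step
latents = blend_latents(image_latents, noise, sigma=0.9)   # starting point handed to the denoise loop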
@@ -342,14 +365,14 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
         block_state.mask = torch.nn.functional.interpolate(
             block_state.mask_image,
             size=(height_latents, width_latents),
-        )
+        )

         block_state.mask = block_state.mask.unsqueeze(2)
         block_state.mask = block_state.mask.repeat(1, components.num_channels_latents, 1, 1, 1)
         block_state.mask = block_state.mask.to(device=device, dtype=block_state.dtype)
-
+
         self.set_block_state(state, block_state)
-
+
         return components, state

@@ -381,14 +404,14 @@ def __init__(self, input_names: List[str] = ["image_latents"]):
             input_names = [input_names]
         self._latents_input_names = input_names
         super().__init__()
-
+
     @staticmethod
     def check_input_shape(latents_input, latents_input_name, batch_size):
         if latents_input is not None and latents_input.shape[0] != 1 and latents_input.shape[0] != batch_size:
             raise ValueError(
                 f"`{latents_input_name}` must have have batch size 1 or {batch_size}, but got {latents_input.shape[0]}"
             )
-
+
         if latents_input.ndim != 5 and latents_input.ndim != 4:
             raise ValueError(f"`{latents_input_name}` must have 4 or 5 dimensions, but got {latents_input.ndim}")

@@ -526,11 +549,12 @@ def inputs(self) -> List[InputParam]:
     def intermediate_outputs(self) -> List[OutputParam]:
         return [
             OutputParam(
-                name="timesteps", type_hint=torch.Tensor, description="The timesteps to use for the denoising process. Can be generated in set_timesteps step."
+                name="timesteps",
+                type_hint=torch.Tensor,
+                description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
             ),
         ]

-
     def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)

@@ -609,9 +633,14 @@ def intermediate_outputs(self) -> List[OutputParam]:
     def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)

-
         block_state.img_shapes = [
-            [(1, block_state.height // components.vae_scale_factor // 2, block_state.width // components.vae_scale_factor // 2)]
+            [
+                (
+                    1,
+                    block_state.height // components.vae_scale_factor // 2,
+                    block_state.width // components.vae_scale_factor // 2,
+                )
+            ]
             * block_state.batch_size
         ]
         block_state.txt_seq_lens = (
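
The img_shapes expression reformatted above records the packed latent grid per prompt: the pixel size divided by the VAE scale factor and then halved again by the 2x2 patchification. A tiny arithmetic sketch with illustrative values:

# Sketch of the img_shapes computation; 1024px with vae_scale_factor=8 and 2x2 patches -> a 64x64 grid.
height, width = 1024, 1024
vae_scale_factor = 8
batch_size = 2

img_shapes = [
    [(1, height // vae_scale_factor // 2, width // vae_scale_factor // 2)]
    * batch_size
]
print(img_shapes)  # [[(1, 64, 64), (1, 64, 64)]]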

src/diffusers/modular_pipelines/qwenimage/decoders.py

Lines changed: 5 additions & 3 deletions
@@ -103,7 +103,7 @@ def intermediate_outputs(self) -> List[str]:
     def check_inputs(output_type):
         if output_type not in ["pil", "np", "pt"]:
             raise ValueError(f"Invalid output_type: {output_type}")
-
+
     def __init__(self, include_image_processor: bool = True):
         self._include_image_processor = include_image_processor
         super().__init__()
@@ -118,7 +118,9 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
         block_state.width = block_state.width or components.default_width

         # YiYi Notes: remove support for output_type = "latents', we can just skip decode/encode step in modular
-        block_state.latents = unpack_latents(block_state.latents, block_state.height, block_state.width, components.vae_scale_factor)
+        block_state.latents = unpack_latents(
+            block_state.latents, block_state.height, block_state.width, components.vae_scale_factor
+        )
         block_state.latents = block_state.latents.to(components.vae.dtype)

         latents_mean = (
@@ -131,7 +133,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
         ).to(block_state.latents.device, block_state.latents.dtype)
         block_state.latents = block_state.latents / latents_std + latents_mean
         block_state.images = components.vae.decode(block_state.latents, return_dict=False)[0][:, :, 0]
-
+
         if self._include_image_processor:
             block_state.images = components.image_processor.postprocess(
                 block_state.images, output_type=block_state.output_type
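
The decode step above un-packs the latents, undoes the VAE's per-channel normalization with latents / latents_std + latents_mean, and only then calls vae.decode. A short sketch of that de-normalization; the channel count, the statistics, and the assumption that latents_std stores a reciprocal scale are illustrative, with the real values coming from the VAE config:

import torch

num_channels = 16
latents = torch.randn(1, num_channels, 1, 64, 64)  # (batch, channels, frame, height, width); shapes illustrative

latents_mean = torch.zeros(1, num_channels, 1, 1, 1)       # per-channel mean from the VAE config
latents_std = 2.0 * torch.ones(1, num_channels, 1, 1, 1)   # assumed to hold 1 / sigma, so dividing rescales

# mirrors the line in the diff: undo (z - mean) * (1 / sigma)
denormalized = latents / latents_std + latents_mean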
