Commit b89cc40

Commit message: up
Parent: 2d5d876
File tree: 4 files changed (+163 −164 lines)

src/diffusers/modular_pipelines/qwenimage/decoders.py

Lines changed: 43 additions & 1 deletion

@@ -19,7 +19,7 @@
 import torch
 
 from ...configuration_utils import FrozenDict
-from ...image_processor import VaeImageProcessor
+from ...image_processor import InpaintProcessor, VaeImageProcessor
 from ...models import AutoencoderKLQwenImage
 from ...utils import logging
 from ..modular_pipeline import ModularPipelineBlocks, PipelineState
@@ -141,3 +141,45 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
 
         self.set_block_state(state, block_state)
         return components, state
+
+
+class QwenImageInpaintProcessImagesOutputStep(ModularPipelineBlocks):
+    model_name = "qwenimage"
+
+    @property
+    def description(self) -> str:
+        return "Postprocess the generated image, optionally applying the mask overlay to the original image."
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec(
+                "image_mask_processor",
+                InpaintProcessor,
+                config=FrozenDict({"vae_scale_factor": 16}),
+                default_creation_method="from_config",
+            ),
+        ]
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam("images", required=True, description="The generated images from the decoders step"),
+            InputParam("original_image"),
+            InputParam("original_mask"),
+            InputParam("crop_coords"),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+        block_state = self.get_block_state(state)
+
+        block_state.images = components.image_mask_processor.postprocess(
+            image=block_state.images,
+            original_image=block_state.original_image,
+            original_mask=block_state.original_mask,
+            crops_coords=block_state.crop_coords,
+        )
+
+        self.set_block_state(state, block_state)
+        return components, state
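For context, when original_image, original_mask, and crop_coords are available, InpaintProcessor.postprocess pastes the generated pixels back into the untouched original so that only the masked region changes. A minimal sketch of that overlay idea, assuming PIL images; the helper name and the plain compositing are illustrative assumptions, not the diffusers implementation:

from PIL import Image

def apply_mask_overlay_sketch(generated: Image.Image, original: Image.Image, mask: Image.Image) -> Image.Image:
    # Keep original pixels where the mask is black; take generated pixels where it is white.
    generated = generated.resize(original.size)
    mask = mask.convert("L").resize(original.size)
    return Image.composite(generated, original, mask)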

src/diffusers/modular_pipelines/qwenimage/encoders.py

Lines changed: 112 additions & 2 deletions

@@ -12,15 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union
 
 import torch
 from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
 
 from ...configuration_utils import FrozenDict
 from ...guiders import ClassifierFreeGuidance
-from ...image_processor import VaeImageProcessor
+from ...image_processor import InpaintProcessor, VaeImageProcessor
 from ...models import AutoencoderKLQwenImage
+from ...pipelines.qwenimage.pipeline_qwenimage_edit import calculate_dimensions
 from ...utils import logging
 from ..modular_pipeline import ModularPipelineBlocks, PipelineState
 from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
@@ -166,6 +167,51 @@ def encode_vae_image(
         return image_latents
 
 
+class QwenImageEditResizeStep(ModularPipelineBlocks):
+    model_name = "qwenimage"
+
+    @property
+    def description(self) -> str:
+        return "Image Resize step that resizes the image to the target area while maintaining the aspect ratio."
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec(
+                "image_resize_processor",
+                VaeImageProcessor,
+                config=FrozenDict({"vae_scale_factor": 16}),
+                default_creation_method="from_config",
+            ),
+        ]
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam(name="image", required=True, type_hint=torch.Tensor, description="The image to resize"),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+        block_state = self.get_block_state(state)
+
+        images = block_state.image
+        if not isinstance(images, list):
+            images = [images]
+
+        image_width, image_height = images[0].size
+        calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, image_width / image_height)
+
+        resized_images = [
+            components.image_resize_processor.resize(image, height=calculated_height, width=calculated_width)
+            for image in images
+        ]
+
+        block_state.image = resized_images
+        self.set_block_state(state, block_state)
+        return components, state
+
+
 class QwenImageTextEncoderStep(ModularPipelineBlocks):
     model_name = "qwenimage"
 
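The resize step above delegates to calculate_dimensions from the QwenImage Edit pipeline, which picks a width/height whose product stays near the 1024×1024 target area while preserving the input aspect ratio. A back-of-the-envelope sketch of that computation; the snapping multiple of 32 and the third return value are assumptions here, see pipeline_qwenimage_edit for the canonical helper:

import math

def calculate_dimensions_sketch(target_area: int, ratio: float, multiple: int = 32):
    # Solve width * height ~= target_area subject to width / height = ratio,
    # then snap both sides to the nearest multiple.
    width = math.sqrt(target_area * ratio)
    height = width / ratio
    width = round(width / multiple) * multiple
    height = round(height / multiple) * multiple
    return width, height, width / height

For a 3:2 input, calculate_dimensions_sketch(1024 * 1024, 3 / 2) gives (1248, 832, 1.5), i.e. roughly a megapixel at the original aspect ratio.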
@@ -411,6 +457,70 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
         return components, state
 
 
+class QwenImageInpaintProcessImagesInputStep(ModularPipelineBlocks):
+    model_name = "qwenimage"
+
+    @property
+    def description(self) -> str:
+        return "Image and mask preprocessing step for inpainting; resizes to the target size and optionally crops around the masked region."
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec(
+                "image_mask_processor",
+                InpaintProcessor,
+                config=FrozenDict({"vae_scale_factor": 16}),
+                default_creation_method="from_config",
+            ),
+        ]
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam("image", required=True),
+            InputParam("mask_image", required=True),
+            InputParam("height"),
+            InputParam("width"),
+            InputParam("padding_mask_crop"),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(name="original_image", type_hint=torch.Tensor, description="The original image"),
+            OutputParam(name="original_mask", type_hint=torch.Tensor, description="The original mask"),
+            OutputParam(
+                name="crop_coords",
+                type_hint=List[Tuple[int, int]],
+                description="The crop coordinates to use for the preprocess/postprocess of the image and mask",
+            ),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+        block_state = self.get_block_state(state)
+
+        block_state.height = block_state.height or components.default_height
+        block_state.width = block_state.width or components.default_width
+
+        block_state.image, block_state.mask_image, postprocessing_kwargs = components.image_mask_processor.preprocess(
+            image=block_state.image,
+            mask=block_state.mask_image,
+            height=block_state.height,
+            width=block_state.width,
+            padding_mask_crop=block_state.padding_mask_crop,
+        )
+
+        if postprocessing_kwargs:
+            block_state.original_image = postprocessing_kwargs["original_image"]
+            block_state.original_mask = postprocessing_kwargs["original_mask"]
+            block_state.crop_coords = postprocessing_kwargs["crops_coords"]
+
+        self.set_block_state(state, block_state)
+        return components, state
+
+
 class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks):
     model_name = "qwenimage"
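When padding_mask_crop is set, the preprocess call above narrows the image and mask to a padded bounding box around the masked area and hands the crop coordinates forward so the output step can paste the result back. A rough sketch of how such a crop box can be derived, assuming a PIL mask; illustrative only, diffusers' InpaintProcessor computes the region more carefully (e.g. keeping the crop compatible with the target aspect ratio):

from PIL import Image

def mask_crop_region_sketch(mask: Image.Image, padding: int):
    # Bounding box of the nonzero mask pixels, expanded by `padding`
    # and clamped to the image bounds.
    left, top, right, bottom = mask.convert("L").getbbox()
    x1 = max(left - padding, 0)
    y1 = max(top - padding, 0)
    x2 = min(right + padding, mask.width)
    y2 = min(bottom + padding, mask.height)
    return x1, y1, x2, y2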

src/diffusers/modular_pipelines/qwenimage/input_output_processor.py renamed to src/diffusers/modular_pipelines/qwenimage/inputs.py

Lines changed: 2 additions & 156 deletions
@@ -12,127 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Tuple
+from typing import List
 
 import torch
 
-from ...configuration_utils import FrozenDict
-from ...image_processor import InpaintProcessor, VaeImageProcessor
-from ...pipelines.qwenimage.pipeline_qwenimage_edit import calculate_dimensions
 from ..modular_pipeline import ModularPipelineBlocks, PipelineState
-from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
+from ..modular_pipeline_utils import InputParam, OutputParam
 from .modular_pipeline import QwenImageModularPipeline
 
 
-class QwenImageEditResizeStep(ModularPipelineBlocks):
-    model_name = "qwenimage"
-
-    @property
-    def description(self) -> str:
-        return "Image Resize step that resize the image to the target area while maintaining the aspect ratio."
-
-    @property
-    def expected_components(self) -> List[ComponentSpec]:
-        return [
-            ComponentSpec(
-                "image_resize_processor",
-                VaeImageProcessor,
-                config=FrozenDict({"vae_scale_factor": 16}),
-                default_creation_method="from_config",
-            ),
-        ]
-
-    @property
-    def inputs(self) -> List[InputParam]:
-        return [
-            InputParam(name="image", required=True, type_hint=torch.Tensor, description="The image to resize"),
-        ]
-
-    @torch.no_grad()
-    def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
-        block_state = self.get_block_state(state)
-
-        images = block_state.image
-        if not isinstance(images, list):
-            images = [images]
-
-        image_width, image_height = images[0].size
-        calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, image_width / image_height)
-
-        resized_images = [
-            components.image_processor.resize(image, height=calculated_height, width=calculated_width)
-            for image in images
-        ]
-
-        block_state.image = resized_images
-        self.set_block_state(state, block_state)
-        return components, state
-
-
-class QwenImageInpaintProcessImagesInputStep(ModularPipelineBlocks):
-    model_name = "qwenimage"
-
-    @property
-    def description(self) -> str:
-        return "Image Mask step that resize the image to the target area while maintaining the aspect ratio."
-
-    @property
-    def expected_components(self) -> List[ComponentSpec]:
-        return [
-            ComponentSpec(
-                "image_mask_processor",
-                InpaintProcessor,
-                config=FrozenDict({"vae_scale_factor": 16}),
-                default_creation_method="from_config",
-            ),
-        ]
-
-    @property
-    def inputs(self) -> List[InputParam]:
-        return [
-            InputParam("image", required=True),
-            InputParam("mask_image", required=True),
-            InputParam("height"),
-            InputParam("width"),
-            InputParam("padding_mask_crop"),
-        ]
-
-    @property
-    def intermediate_outputs(self) -> List[OutputParam]:
-        return [
-            OutputParam(name="original_image", type_hint=torch.Tensor, description="The original image"),
-            OutputParam(name="original_mask", type_hint=torch.Tensor, description="The original mask"),
-            OutputParam(
-                name="crop_coords",
-                type_hint=List[Tuple[int, int]],
-                description="The crop coordinates to use for the preprocess/postprocess of the image and mask",
-            ),
-        ]
-
-    @torch.no_grad()
-    def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
-        block_state = self.get_block_state(state)
-
-        block_state.height = block_state.height or components.default_height
-        block_state.width = block_state.width or components.default_width
-
-        block_state.image, block_state.mask_image, postprocessing_kwargs = components.image_mask_processor.preprocess(
-            image=block_state.image,
-            mask=block_state.mask_image,
-            height=block_state.height,
-            width=block_state.width,
-            padding_mask_crop=block_state.padding_mask_crop,
-        )
-
-        if postprocessing_kwargs:
-            block_state.original_image = postprocessing_kwargs["original_image"]
-            block_state.original_mask = postprocessing_kwargs["original_mask"]
-            block_state.crop_coords = postprocessing_kwargs["crops_coords"]
-
-        self.set_block_state(state, block_state)
-        return components, state
-
-
 class QwenImageInputsDynamicStep(ModularPipelineBlocks):
     model_name = "qwenimage"
 
@@ -322,45 +210,3 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
         self.set_block_state(state, block_state)
 
         return components, state
-
-
-class QwenImageInpaintProcessImagesOutputStep(ModularPipelineBlocks):
-    model_name = "qwenimage"
-
-    @property
-    def description(self) -> str:
-        return "postprocess the generated image, optional apply the mask overally to the original image.."
-
-    @property
-    def expected_components(self) -> List[ComponentSpec]:
-        return [
-            ComponentSpec(
-                "image_mask_processor",
-                InpaintProcessor,
-                config=FrozenDict({"vae_scale_factor": 16}),
-                default_creation_method="from_config",
-            ),
-        ]
-
-    @property
-    def inputs(self) -> List[InputParam]:
-        return [
-            InputParam("images", required=True, description="the generated image from decoders step"),
-            InputParam("original_image"),
-            InputParam("original_mask"),
-            InputParam("crop_coords"),
-        ]
-
-    @torch.no_grad()
-    def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
-        block_state = self.get_block_state(state)
-
-        block_state.images = components.image_mask_processor.postprocess(
-            image=block_state.images,
-            original_image=block_state.original_image,
-            original_mask=block_state.original_mask,
-            crops_coords=block_state.crop_coords,
-        )
-
-        self.set_block_state(state, block_state)
-        return components, state
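After this commit the three moved blocks resolve from their new modules. An illustrative set of imports matching the moves above; the module paths are taken from the diff, while availability via package-level re-exports is an assumption not covered by this change:

from diffusers.modular_pipelines.qwenimage.decoders import QwenImageInpaintProcessImagesOutputStep
from diffusers.modular_pipelines.qwenimage.encoders import (
    QwenImageEditResizeStep,
    QwenImageInpaintProcessImagesInputStep,
)
from diffusers.modular_pipelines.qwenimage.inputs import QwenImageInputsDynamicStep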

src/diffusers/modular_pipelines/qwenimage/modular_blocks.py

Lines changed: 6 additions & 5 deletions

@@ -26,20 +26,21 @@
     QwenImageSetTimestepsStep,
     QwenImageSetTimestepsWithStrengthStep,
 )
-from .decoders import QwenImageDecodeDynamicStep
+from .decoders import QwenImageDecodeDynamicStep, QwenImageInpaintProcessImagesOutputStep
 from .denoise import (
     QwenImageControlNetLoopBeforeDenoiser,
     QwenImageDenoiseStep,
     QwenImageEditDenoiseStep,
     QwenImageInpaintDenoiseStep,
 )
-from .encoders import QwenImageEditTextEncoderStep, QwenImageTextEncoderStep, QwenImageVaeEncoderDynamicStep
-from .input_output_processor import (
+from .encoders import (
     QwenImageEditResizeStep,
+    QwenImageEditTextEncoderStep,
     QwenImageInpaintProcessImagesInputStep,
-    QwenImageInpaintProcessImagesOutputStep,
-    QwenImageInputsDynamicStep,
+    QwenImageTextEncoderStep,
+    QwenImageVaeEncoderDynamicStep,
 )
+from .inputs import QwenImageInputsDynamicStep
 
 
 logger = logging.get_logger(__name__)
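Taken together, the inpaint-related blocks imported above chain into a workflow roughly like the following. The ordering and the INPAINT_BLOCKS name are inferred from each block's inputs and outputs, not taken from this diff; the actual presets assembled in modular_blocks.py may differ:

from collections import OrderedDict

# Hypothetical inpaint ordering (illustrative only).
INPAINT_BLOCKS = OrderedDict(
    [
        ("text_encoder", QwenImageTextEncoderStep),
        ("preprocess", QwenImageInpaintProcessImagesInputStep),   # crops image + mask, records crop_coords
        ("vae_encoder", QwenImageVaeEncoderDynamicStep),
        ("set_timesteps", QwenImageSetTimestepsWithStrengthStep),
        ("denoise", QwenImageInpaintDenoiseStep),
        ("decode", QwenImageDecodeDynamicStep),
        ("postprocess", QwenImageInpaintProcessImagesOutputStep),  # overlays result onto the original
    ]
)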
