|  | 
| 12 | 12 | # See the License for the specific language governing permissions and | 
| 13 | 13 | # limitations under the License. | 
| 14 | 14 | 
 | 
| 15 |  | -from typing import List, Tuple | 
|  | 15 | +from typing import List | 
| 16 | 16 | 
 | 
| 17 | 17 | import torch | 
| 18 | 18 | 
 | 
| 19 |  | -from ...configuration_utils import FrozenDict | 
| 20 |  | -from ...image_processor import InpaintProcessor, VaeImageProcessor | 
| 21 |  | -from ...pipelines.qwenimage.pipeline_qwenimage_edit import calculate_dimensions | 
| 22 | 19 | from ..modular_pipeline import ModularPipelineBlocks, PipelineState | 
| 23 |  | -from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam | 
|  | 20 | +from ..modular_pipeline_utils import InputParam, OutputParam | 
| 24 | 21 | from .modular_pipeline import QwenImageModularPipeline | 
| 25 | 22 | 
 | 
| 26 | 23 | 
 | 
| 27 |  | -class QwenImageEditResizeStep(ModularPipelineBlocks): | 
| 28 |  | -    model_name = "qwenimage" | 
| 29 |  | - | 
| 30 |  | -    @property | 
| 31 |  | -    def description(self) -> str: | 
| 32 |  | -        return "Image Resize step that resize the image to the target area while maintaining the aspect ratio." | 
| 33 |  | - | 
| 34 |  | -    @property | 
| 35 |  | -    def expected_components(self) -> List[ComponentSpec]: | 
| 36 |  | -        return [ | 
| 37 |  | -            ComponentSpec( | 
| 38 |  | -                "image_resize_processor", | 
| 39 |  | -                VaeImageProcessor, | 
| 40 |  | -                config=FrozenDict({"vae_scale_factor": 16}), | 
| 41 |  | -                default_creation_method="from_config", | 
| 42 |  | -            ), | 
| 43 |  | -        ] | 
| 44 |  | - | 
| 45 |  | -    @property | 
| 46 |  | -    def inputs(self) -> List[InputParam]: | 
| 47 |  | -        return [ | 
| 48 |  | -            InputParam(name="image", required=True, type_hint=torch.Tensor, description="The image to resize"), | 
| 49 |  | -        ] | 
| 50 |  | - | 
| 51 |  | -    @torch.no_grad() | 
| 52 |  | -    def __call__(self, components: QwenImageModularPipeline, state: PipelineState): | 
| 53 |  | -        block_state = self.get_block_state(state) | 
| 54 |  | - | 
| 55 |  | -        images = block_state.image | 
| 56 |  | -        if not isinstance(images, list): | 
| 57 |  | -            images = [images] | 
| 58 |  | - | 
| 59 |  | -        image_width, image_height = images[0].size | 
| 60 |  | -        calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, image_width / image_height) | 
| 61 |  | - | 
| 62 |  | -        resized_images = [ | 
| 63 |  | -            components.image_processor.resize(image, height=calculated_height, width=calculated_width) | 
| 64 |  | -            for image in images | 
| 65 |  | -        ] | 
| 66 |  | - | 
| 67 |  | -        block_state.image = resized_images | 
| 68 |  | -        self.set_block_state(state, block_state) | 
| 69 |  | -        return components, state | 
| 70 |  | - | 
| 71 |  | - | 
| 72 |  | -class QwenImageInpaintProcessImagesInputStep(ModularPipelineBlocks): | 
| 73 |  | -    model_name = "qwenimage" | 
| 74 |  | - | 
| 75 |  | -    @property | 
| 76 |  | -    def description(self) -> str: | 
| 77 |  | -        return "Image Mask step that resize the image to the target area while maintaining the aspect ratio." | 
| 78 |  | - | 
| 79 |  | -    @property | 
| 80 |  | -    def expected_components(self) -> List[ComponentSpec]: | 
| 81 |  | -        return [ | 
| 82 |  | -            ComponentSpec( | 
| 83 |  | -                "image_mask_processor", | 
| 84 |  | -                InpaintProcessor, | 
| 85 |  | -                config=FrozenDict({"vae_scale_factor": 16}), | 
| 86 |  | -                default_creation_method="from_config", | 
| 87 |  | -            ), | 
| 88 |  | -        ] | 
| 89 |  | - | 
| 90 |  | -    @property | 
| 91 |  | -    def inputs(self) -> List[InputParam]: | 
| 92 |  | -        return [ | 
| 93 |  | -            InputParam("image", required=True), | 
| 94 |  | -            InputParam("mask_image", required=True), | 
| 95 |  | -            InputParam("height"), | 
| 96 |  | -            InputParam("width"), | 
| 97 |  | -            InputParam("padding_mask_crop"), | 
| 98 |  | -        ] | 
| 99 |  | - | 
| 100 |  | -    @property | 
| 101 |  | -    def intermediate_outputs(self) -> List[OutputParam]: | 
| 102 |  | -        return [ | 
| 103 |  | -            OutputParam(name="original_image", type_hint=torch.Tensor, description="The original image"), | 
| 104 |  | -            OutputParam(name="original_mask", type_hint=torch.Tensor, description="The original mask"), | 
| 105 |  | -            OutputParam( | 
| 106 |  | -                name="crop_coords", | 
| 107 |  | -                type_hint=List[Tuple[int, int]], | 
| 108 |  | -                description="The crop coordinates to use for the preprocess/postprocess of the image and mask", | 
| 109 |  | -            ), | 
| 110 |  | -        ] | 
| 111 |  | - | 
| 112 |  | -    @torch.no_grad() | 
| 113 |  | -    def __call__(self, components: QwenImageModularPipeline, state: PipelineState): | 
| 114 |  | -        block_state = self.get_block_state(state) | 
| 115 |  | - | 
| 116 |  | -        block_state.height = block_state.height or components.default_height | 
| 117 |  | -        block_state.width = block_state.width or components.default_width | 
| 118 |  | - | 
| 119 |  | -        block_state.image, block_state.mask_image, postprocessing_kwargs = components.image_mask_processor.preprocess( | 
| 120 |  | -            image=block_state.image, | 
| 121 |  | -            mask=block_state.mask_image, | 
| 122 |  | -            height=block_state.height, | 
| 123 |  | -            width=block_state.width, | 
| 124 |  | -            padding_mask_crop=block_state.padding_mask_crop, | 
| 125 |  | -        ) | 
| 126 |  | - | 
| 127 |  | -        if postprocessing_kwargs: | 
| 128 |  | -            block_state.original_image = postprocessing_kwargs["original_image"] | 
| 129 |  | -            block_state.original_mask = postprocessing_kwargs["original_mask"] | 
| 130 |  | -            block_state.crop_coords = postprocessing_kwargs["crops_coords"] | 
| 131 |  | - | 
| 132 |  | -        self.set_block_state(state, block_state) | 
| 133 |  | -        return components, state | 
| 134 |  | - | 
| 135 |  | - | 
| 136 | 24 | class QwenImageInputsDynamicStep(ModularPipelineBlocks): | 
| 137 | 25 |     model_name = "qwenimage" | 
| 138 | 26 | 
 | 
| @@ -322,45 +210,3 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - | 
| 322 | 210 |         self.set_block_state(state, block_state) | 
| 323 | 211 | 
 | 
| 324 | 212 |         return components, state | 
| 325 |  | - | 
| 326 |  | - | 
| 327 |  | -class QwenImageInpaintProcessImagesOutputStep(ModularPipelineBlocks): | 
| 328 |  | -    model_name = "qwenimage" | 
| 329 |  | - | 
| 330 |  | -    @property | 
| 331 |  | -    def description(self) -> str: | 
| 332 |  | -        return "postprocess the generated image, optional apply the mask overally to the original image.." | 
| 333 |  | - | 
| 334 |  | -    @property | 
| 335 |  | -    def expected_components(self) -> List[ComponentSpec]: | 
| 336 |  | -        return [ | 
| 337 |  | -            ComponentSpec( | 
| 338 |  | -                "image_mask_processor", | 
| 339 |  | -                InpaintProcessor, | 
| 340 |  | -                config=FrozenDict({"vae_scale_factor": 16}), | 
| 341 |  | -                default_creation_method="from_config", | 
| 342 |  | -            ), | 
| 343 |  | -        ] | 
| 344 |  | - | 
| 345 |  | -    @property | 
| 346 |  | -    def inputs(self) -> List[InputParam]: | 
| 347 |  | -        return [ | 
| 348 |  | -            InputParam("images", required=True, description="the generated image from decoders step"), | 
| 349 |  | -            InputParam("original_image"), | 
| 350 |  | -            InputParam("original_mask"), | 
| 351 |  | -            InputParam("crop_coords"), | 
| 352 |  | -        ] | 
| 353 |  | - | 
| 354 |  | -    @torch.no_grad() | 
| 355 |  | -    def __call__(self, components: QwenImageModularPipeline, state: PipelineState): | 
| 356 |  | -        block_state = self.get_block_state(state) | 
| 357 |  | - | 
| 358 |  | -        block_state.images = components.image_mask_processor.postprocess( | 
| 359 |  | -            image=block_state.images, | 
| 360 |  | -            original_image=block_state.original_image, | 
| 361 |  | -            original_mask=block_state.original_mask, | 
| 362 |  | -            crops_coords=block_state.crop_coords, | 
| 363 |  | -        ) | 
| 364 |  | - | 
| 365 |  | -        self.set_block_state(state, block_state) | 
| 366 |  | -        return components, state | 
0 commit comments