@@ -128,6 +128,61 @@ def get_qwen_prompt_embeds_edit(
     return prompt_embeds, encoder_attention_mask


+def get_qwen_prompt_embeds_edit_plus(
+    text_encoder,
+    processor,
+    prompt: Union[str, List[str]] = None,
+    image: Optional[Union[torch.Tensor, List[PIL.Image.Image], PIL.Image.Image]] = None,
+    prompt_template_encode: str = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
+    img_template_encode: str = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>",
+    prompt_template_encode_start_idx: int = 64,
+    device: Optional[torch.device] = None,
+):
+    prompt = [prompt] if isinstance(prompt, str) else prompt
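+    # One "Picture N: <|vision_start|>...<|vision_end|>" placeholder is prepended per reference
+    # image so the encoder can tie each image to its index in the edit instruction.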
+    if isinstance(image, list):
+        base_img_prompt = ""
+        for i, img in enumerate(image):
+            base_img_prompt += img_template_encode.format(i + 1)
+    elif image is not None:
+        base_img_prompt = img_template_encode.format(1)
+    else:
+        base_img_prompt = ""
+
+    template = prompt_template_encode
+
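+    # The first `drop_idx` tokens correspond to the fixed template prefix; they are stripped
+    # from the encoder hidden states below so only image and instruction tokens remain.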
+    drop_idx = prompt_template_encode_start_idx
+    txt = [template.format(base_img_prompt + e) for e in prompt]
+
+    model_inputs = processor(
+        text=txt,
+        images=image,
+        padding=True,
+        return_tensors="pt",
+    ).to(device)
+    outputs = text_encoder(
+        input_ids=model_inputs.input_ids,
+        attention_mask=model_inputs.attention_mask,
+        pixel_values=model_inputs.pixel_values,
+        image_grid_thw=model_inputs.image_grid_thw,
+        output_hidden_states=True,
+    )
+
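+    # Take the last hidden layer, keep only each sample's unpadded tokens, drop the template
+    # prefix, then right-pad the batch to the longest remaining sequence.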
+    hidden_states = outputs.hidden_states[-1]
+    split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask)
+    split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+    attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
+    max_seq_len = max([e.size(0) for e in split_hidden_states])
+    prompt_embeds = torch.stack(
+        [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
+    )
+    encoder_attention_mask = torch.stack(
+        [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
+    )
+
+    prompt_embeds = prompt_embeds.to(device=device)
+    return prompt_embeds, encoder_attention_mask
+
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
 def retrieve_latents(
     encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
@@ -266,6 +321,83 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
         return components, state


+class QwenImageEditPlusResizeDynamicStep(QwenImageEditResizeDynamicStep):
+    model_name = "qwenimage"
+
+    def __init__(
+        self,
+        input_name: str = "image",
+        output_name: str = "resized_image",
+        vae_image_output_name: str = "vae_image",
+    ):
+        """Create a configurable step for resizing images to the target area (384 * 384) while maintaining the aspect ratio.
+
+        This block resizes an input image or a list of input images and exposes the resized result under configurable
+        input and output names. Use this when you need to wire the resize step to different image fields (e.g.,
+        "image", "control_image").
+
+        Args:
+            input_name (str, optional): Name of the image field to read from the
+                pipeline state. Defaults to "image".
+            output_name (str, optional): Name of the resized image field to write
+                back to the pipeline state. Defaults to "resized_image".
+            vae_image_output_name (str, optional): Name of the image field
+                to write back to the pipeline state. This is used by the VAE encoder step later on. QwenImage Edit Plus
+                processes the input image(s) differently for the VL encoder and the VAE.
+        """
+        if not isinstance(input_name, str) or not isinstance(output_name, str):
+            raise ValueError(
+                f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}"
+            )
+        self.condition_image_size = 384 * 384
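+        # Target area for the VL-encoder condition images; the VAE path gets the original,
+        # un-resized images via `vae_image_output_name` (see __call__ below).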
+        self._image_input_name = input_name
+        self._resized_image_output_name = output_name
+        self._vae_image_output_name = vae_image_output_name
+        super().__init__(input_name=input_name, output_name=output_name)
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return super().intermediate_outputs + [
+            OutputParam(
+                name=self._vae_image_output_name,
+                type_hint=List[PIL.Image.Image],
+                description="The images to be processed which will be further used by the VAE encoder.",
+            ),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+        block_state = self.get_block_state(state)
+
+        images = getattr(block_state, self._image_input_name)
+
+        if not is_valid_image_imagelist(images):
+            raise ValueError(f"Images must be an image or a list of images but are {type(images)}")
+
+        if isinstance(images, PIL.Image.Image):
+            images = [images]
+
+        # TODO (sayakpaul): revisit this when the inputs are `torch.Tensor`s
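+        # Each condition image is resized to roughly the 384 * 384 target area (aspect ratio
+        # preserved) for the VL encoder, while the original image is kept for the VAE path.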
+        condition_images = []
+        vae_images = []
+        for img in images:
+            image_width, image_height = img.size
+            condition_width, condition_height, _ = calculate_dimensions(
+                self.condition_image_size, image_width / image_height
+            )
+            condition_images.append(components.image_resize_processor.resize(img, condition_height, condition_width))
+            vae_images.append(img)
+
+        setattr(block_state, self._resized_image_output_name, condition_images)
+        setattr(block_state, self._vae_image_output_name, vae_images)
+        self.set_block_state(state, block_state)
+        return components, state
+
+
 class QwenImageTextEncoderStep(ModularPipelineBlocks):
     model_name = "qwenimage"

@@ -511,6 +643,61 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
         return components, state


+class QwenImageEditPlusTextEncoderStep(QwenImageEditTextEncoderStep):
+    model_name = "qwenimage"
+
+    @property
+    def expected_configs(self) -> List[ConfigSpec]:
+        return [
+            ConfigSpec(
+                name="prompt_template_encode",
+                default="<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
+            ),
+            ConfigSpec(
+                name="img_template_encode",
+                default="Picture {}: <|vision_start|><|image_pad|><|vision_end|>",
+            ),
+            ConfigSpec(name="prompt_template_encode_start_idx", default=64),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+        block_state = self.get_block_state(state)
+
+        self.check_inputs(block_state.prompt, block_state.negative_prompt)
+
+        device = components._execution_device
+
+        block_state.prompt_embeds, block_state.prompt_embeds_mask = get_qwen_prompt_embeds_edit_plus(
+            components.text_encoder,
+            components.processor,
+            prompt=block_state.prompt,
+            image=block_state.resized_image,
+            prompt_template_encode=components.config.prompt_template_encode,
+            img_template_encode=components.config.img_template_encode,
+            prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
+            device=device,
+        )
+
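+        # Unconditional embeddings reuse the same reference images and templates; only the
+        # text changes (a single-space prompt when no negative prompt is provided).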
+        if components.requires_unconditional_embeds:
+            negative_prompt = block_state.negative_prompt or " "
+            block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = (
+                get_qwen_prompt_embeds_edit_plus(
+                    components.text_encoder,
+                    components.processor,
+                    prompt=negative_prompt,
+                    image=block_state.resized_image,
+                    prompt_template_encode=components.config.prompt_template_encode,
+                    img_template_encode=components.config.img_template_encode,
+                    prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
+                    device=device,
+                )
+            )
+
+        self.set_block_state(state, block_state)
+        return components, state
+
+
 class QwenImageInpaintProcessImagesInputStep(ModularPipelineBlocks):
     model_name = "qwenimage"

@@ -612,12 +799,7 @@ def expected_components(self) -> List[ComponentSpec]:

     @property
     def inputs(self) -> List[InputParam]:
-        return [
-            InputParam("resized_image"),
-            InputParam("image"),
-            InputParam("height"),
-            InputParam("width"),
-        ]
+        return [InputParam("resized_image"), InputParam("image"), InputParam("height"), InputParam("width")]

     @property
     def intermediate_outputs(self) -> List[OutputParam]:
@@ -661,6 +843,47 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
         return components, state


+class QwenImageEditPlusProcessImagesInputStep(QwenImageProcessImagesInputStep):
+    model_name = "qwenimage-edit-plus"
+    vae_image_size = 1024 * 1024
+
+    @property
+    def description(self) -> str:
+        return "Image preprocess step for QwenImage Edit Plus. Unlike QwenImage Edit, Edit Plus does not reuse the VL-resized image for VAE preprocessing."
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [InputParam("vae_image"), InputParam("image"), InputParam("height"), InputParam("width")]
+
+    @torch.no_grad()
+    def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+        block_state = self.get_block_state(state)
+
+        if block_state.vae_image is None and block_state.image is None:
+            raise ValueError("`vae_image` and `image` cannot be None at the same time")
+
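+        # Prefer `vae_image` (produced by the Edit Plus resize step) at its native size;
+        # otherwise fall back to the raw `image` input with validated height/width.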
+        if block_state.vae_image is None:
+            image = block_state.image
+            self.check_inputs(
+                height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor
+            )
+            height = block_state.height or components.default_height
+            width = block_state.width or components.default_width
+            block_state.processed_image = components.image_processor.preprocess(
+                image=image, height=height, width=width
+            )
+        else:
+            width, height = block_state.vae_image[0].size
+            image = block_state.vae_image
+
+            block_state.processed_image = components.image_processor.preprocess(
+                image=image, height=height, width=width
+            )
+
+        self.set_block_state(state, block_state)
+        return components, state
+
+
 class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks):
     model_name = "qwenimage"

@@ -738,7 +961,6 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
             dtype=dtype,
             latent_channels=components.num_channels_latents,
         )
-
         setattr(block_state, self._image_latents_output_name, image_latents)

         self.set_block_state(state, block_state)