 
 from .modular_pipeline import QwenImageModularPipeline
 
+from ...pipelines.qwenimage.pipeline_qwenimage import calculate_dimensions
+
 logger = logging.get_logger(__name__)
 
 
+def _extract_masked_hidden(hidden_states: torch.Tensor, mask: torch.Tensor):
+    bool_mask = mask.bool()
+    valid_lengths = bool_mask.sum(dim=1)
+    selected = hidden_states[bool_mask]
+    split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
+    return split_result
+
 def get_qwen_prompt_embeds(
     text_encoder,
     tokenizer,
@@ -53,13 +62,6 @@ def get_qwen_prompt_embeds(
     )
     hidden_states = encoder_hidden_states.hidden_states[-1]
 
-    def _extract_masked_hidden(hidden_states: torch.Tensor, mask: torch.Tensor):
-        bool_mask = mask.bool()
-        valid_lengths = bool_mask.sum(dim=1)
-        selected = hidden_states[bool_mask]
-        split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
-        return split_result
-
     split_hidden_states = _extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
     split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
     attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
@@ -75,6 +77,55 @@ def _extract_masked_hidden(hidden_states: torch.Tensor, mask: torch.Tensor):
 
     return prompt_embeds, encoder_attention_mask
 
+
+def get_qwen_prompt_embeds_edit(
+    text_encoder,
+    processor,
+    prompt: Union[str, List[str]] = None,
+    image: Optional[torch.Tensor] = None,
+    prompt_template_encode: str = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n",
+    prompt_template_encode_start_idx: int = 64,
+    device: Optional[torch.device] = None,
+    dtype: Optional[torch.dtype] = None,
+):
+    prompt = [prompt] if isinstance(prompt, str) else prompt
+
+    template = prompt_template_encode
+    drop_idx = prompt_template_encode_start_idx
+    txt = [template.format(e) for e in prompt]
+
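+    # Unlike the text-only path, the edit path runs the multimodal processor over
+    # text and image together, so the reference image is tokenized into the prompt.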
+    model_inputs = processor(
+        text=txt,
+        images=image,
+        padding=True,
+        return_tensors="pt",
+    ).to(device)
+
+    outputs = text_encoder(
+        input_ids=model_inputs.input_ids,
+        attention_mask=model_inputs.attention_mask,
+        pixel_values=model_inputs.pixel_values,
+        image_grid_thw=model_inputs.image_grid_thw,
+        output_hidden_states=True,
+    )
+
+    hidden_states = outputs.hidden_states[-1]
+    split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask)
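+    # Drop the first `drop_idx` states, i.e. the tokens of the fixed chat-template
+    # prefix, so only the user prompt (and image) tokens remain.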
+    split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+    attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
+    max_seq_len = max([e.size(0) for e in split_hidden_states])
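+    # Right-pad each sequence to the batch max and stack into dense [B, L, D]
+    # embeddings with a matching [B, L] attention mask.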
+    prompt_embeds = torch.stack(
+        [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
+    )
+    encoder_attention_mask = torch.stack(
+        [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
+    )
+
+    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+    return prompt_embeds, encoder_attention_mask
+
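+# Minimal usage sketch for the helper above (hypothetical variable names; assumes
+# a loaded Qwen2_5_VLForConditionalGeneration and its Qwen2VLProcessor):
+#
+#   embeds, mask = get_qwen_prompt_embeds_edit(
+#       text_encoder, processor, prompt="make the sky pink",
+#       image=reference_image, device="cuda", dtype=torch.bfloat16,
+#   )
+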
 class QwenImageTextEncoderStep(ModularPipelineBlocks):
     model_name = "qwenimage"
 
@@ -139,6 +190,137 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
         device = components._execution_device
         self.check_inputs(block_state.prompt, block_state.negative_prompt, block_state.max_sequence_length)
 
+        block_state.prompt_embeds, block_state.prompt_embeds_mask = get_qwen_prompt_embeds(
+            components.text_encoder,
+            components.tokenizer,
+            prompt=block_state.prompt,
+            prompt_template_encode=components.config.prompt_template_encode,
+            prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
+            tokenizer_max_length=components.config.tokenizer_max_length,
+            device=device,
+        )
+
+        block_state.prompt_embeds = block_state.prompt_embeds[:, :block_state.max_sequence_length]
+        block_state.prompt_embeds_mask = block_state.prompt_embeds_mask[:, :block_state.max_sequence_length]
+
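+        # Negative embeddings are only computed when the pipeline's guider reports
+        # that unconditional embeddings are required (e.g. classifier-free guidance).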
+        if components.requires_unconditional_embeds:
+            block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds(
+                components.text_encoder,
+                components.tokenizer,
+                prompt=block_state.negative_prompt,
+                prompt_template_encode=components.config.prompt_template_encode,
+                prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
+                tokenizer_max_length=components.config.tokenizer_max_length,
+                device=device,
+            )
+            block_state.negative_prompt_embeds = block_state.negative_prompt_embeds[:, :block_state.max_sequence_length]
+            block_state.negative_prompt_embeds_mask = block_state.negative_prompt_embeds_mask[:, :block_state.max_sequence_length]
+
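+        # Write the results back into the shared PipelineState so downstream
+        # blocks can consume the (negative_)prompt_embeds and their masks.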
+        self.set_block_state(state, block_state)
+        return components, state
+
+
+class QwenImageImageResizeStep(ModularPipelineBlocks):
+    model_name = "qwenimage"
+
+    @property
+    def description(self) -> str:
+        return "Image Resize step that resizes the image to the target area while maintaining the aspect ratio"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec(
+                "image_processor",
+                VaeImageProcessor,
+                config=FrozenDict({"vae_scale_factor": 16}),
+                default_creation_method="from_config",
+            ),
+        ]
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam(name="image", required=True, type_hint=torch.Tensor, description="The image to resize"),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+        block_state = self.get_block_state(state)
+
+        if not isinstance(block_state.image, list):
+            block_state.image = [block_state.image]
+
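+        # Fit the first image's aspect ratio into a ~1024x1024 pixel target area;
+        # calculate_dimensions returns the width/height to use for resizing.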
+        image_width, image_height = block_state.image[0].size
+        calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_width / image_height)
+
+        block_state.image = [
+            components.image_processor.resize(img, height=calculated_height, width=calculated_width)
+            for img in block_state.image
+        ]
+        self.set_block_state(state, block_state)
+        return components, state
+
+
+class QwenImageEditTextEncoderStep(ModularPipelineBlocks):
+    model_name = "qwenimage"
+
+    @property
+    def description(self) -> str:
+        return "Text Encoder step that generates text embeddings to guide the image generation"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("text_encoder", Qwen2_5_VLForConditionalGeneration),
+            ComponentSpec("processor", Qwen2VLProcessor),
+            ComponentSpec(
+                "guider",
+                ClassifierFreeGuidance,
+                config=FrozenDict({"guidance_scale": 4.0}),
+                default_creation_method="from_config",
+            ),
+        ]
+
+    @property
+    def expected_configs(self) -> List[ConfigSpec]:
+        return [
+            ConfigSpec(
+                name="prompt_template_encode",
+                default="<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n",
+            ),
+            ConfigSpec(name="prompt_template_encode_start_idx", default=64),
+        ]
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam(name="prompt", required=True, type_hint=str, description="The prompt to encode"),
+            InputParam(name="negative_prompt", type_hint=str, description="The negative prompt to encode"),
+            InputParam(name="max_sequence_length", type_hint=int, description="The max sequence length to use", default=1024),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(name="prompt_embeds", kwargs_type="guider_input_fields", type_hint=torch.Tensor, description="The prompt embeddings"),
+            OutputParam(name="prompt_embeds_mask", kwargs_type="guider_input_fields", type_hint=torch.Tensor, description="The encoder attention mask"),
+            OutputParam(name="negative_prompt_embeds", kwargs_type="guider_input_fields", type_hint=torch.Tensor, description="The negative prompt embeddings"),
+            OutputParam(name="negative_prompt_embeds_mask", kwargs_type="guider_input_fields", type_hint=torch.Tensor, description="The negative prompt embeddings mask"),
+        ]
+
+    @staticmethod
+    def check_inputs(prompt, negative_prompt, max_sequence_length):
+        if not isinstance(prompt, str) and not isinstance(prompt, list):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if negative_prompt is not None and not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list):
+            raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+        if max_sequence_length is not None and max_sequence_length > 1024:
+            raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
+
+    @torch.no_grad()
+    def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+        block_state = self.get_block_state(state)
+
+        self.check_inputs(block_state.prompt, block_state.negative_prompt, block_state.max_sequence_length)
+
+        device = components._execution_device
+        image = components.image_processor.preprocess(block_state.image)
+
         block_state.prompt_embeds, block_state.prompt_embeds_mask = get_qwen_prompt_embeds(
             components.text_encoder,
             components.tokenizer,