@@ -102,6 +102,74 @@ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor

         return split_result

+    def _get_qwen_prompt_embeds(
+        self,
+        prompt: Union[str, List[str]],
+        image: Optional[torch.Tensor] = None,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        device = device or self._execution_device
+        dtype = dtype or self.text_encoder.dtype
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        template = self.prompt_template_encode
+        drop_idx = self.prompt_template_encode_start_idx
+        txt = [template.format(e) for e in prompt]
+        use_multimodal = image is not None and hasattr(self, "processor")
+
+        if use_multimodal:
+            # --- Multimodal (text + image) ---
+            model_inputs = self.processor(
+                text=txt,
+                images=image,
+                padding=True,
+                return_tensors="pt",
+            ).to(device)
+
+            outputs = self.text_encoder(
+                input_ids=model_inputs.input_ids,
+                attention_mask=model_inputs.attention_mask,
+                pixel_values=model_inputs.pixel_values,
+                image_grid_thw=model_inputs.image_grid_thw,
+                output_hidden_states=True,
+            )
+            hidden_states = outputs.hidden_states[-1]
+            attn_mask = model_inputs.attention_mask
+        else:
+            # --- Text-only ---
+            txt_tokens = self.tokenizer(
+                txt,
+                max_length=self.tokenizer_max_length + drop_idx,
+                padding=True,
+                truncation=True,
+                return_tensors="pt",
+            ).to(device)
+
+            outputs = self.text_encoder(
+                input_ids=txt_tokens.input_ids,
+                attention_mask=txt_tokens.attention_mask,
+                output_hidden_states=True,
+            )
+            hidden_states = outputs.hidden_states[-1]
+            attn_mask = txt_tokens.attention_mask
+
+        split_hidden_states = self._extract_masked_hidden(hidden_states, attn_mask)
+        split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+
+        attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
+        max_seq_len = max(e.size(0) for e in split_hidden_states)
+
+        prompt_embeds = torch.stack(
+            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
+        )
+        encoder_attention_mask = torch.stack(
+            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
+        )
+
+        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+        return prompt_embeds, encoder_attention_mask
+
     @staticmethod
     def _pack_latents(latents, batch_size, num_channels_latents, height, width):
         latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
@@ -171,44 +239,6 @@ def encode_prompt(

         return prompt_embeds, prompt_embeds_mask

-    def _get_qwen_prompt_embeds(
-        self,
-        prompt: Union[str, List[str]] = None,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-    ):
-        device = device or self._execution_device
-        dtype = dtype or self.text_encoder.dtype
-
-        prompt = [prompt] if isinstance(prompt, str) else prompt
-
-        template = self.prompt_template_encode
-        drop_idx = self.prompt_template_encode_start_idx
-        txt = [template.format(e) for e in prompt]
-        txt_tokens = self.tokenizer(
-            txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
-        ).to(device)
-        encoder_hidden_states = self.text_encoder(
-            input_ids=txt_tokens.input_ids,
-            attention_mask=txt_tokens.attention_mask,
-            output_hidden_states=True,
-        )
-        hidden_states = encoder_hidden_states.hidden_states[-1]
-        split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
-        split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
-        attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
-        max_seq_len = max([e.size(0) for e in split_hidden_states])
-        prompt_embeds = torch.stack(
-            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
-        )
-        encoder_attention_mask = torch.stack(
-            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
-        )
-
-        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-
-        return prompt_embeds, encoder_attention_mask
-

 class QwenImageEditPipelineMixin(QwenImageMixin):
     def encode_prompt(
@@ -252,53 +282,6 @@ def encode_prompt(

         return prompt_embeds, prompt_embeds_mask

-    def _get_qwen_prompt_embeds(
-        self,
-        prompt: Union[str, List[str]] = None,
-        image: Optional[torch.Tensor] = None,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-    ):
-        device = device or self._execution_device
-        dtype = dtype or self.text_encoder.dtype
-
-        prompt = [prompt] if isinstance(prompt, str) else prompt
-
-        template = self.prompt_template_encode
-        drop_idx = self.prompt_template_encode_start_idx
-        txt = [template.format(e) for e in prompt]
-
-        model_inputs = self.processor(
-            text=txt,
-            images=image,
-            padding=True,
-            return_tensors="pt",
-        ).to(device)
-
-        outputs = self.text_encoder(
-            input_ids=model_inputs.input_ids,
-            attention_mask=model_inputs.attention_mask,
-            pixel_values=model_inputs.pixel_values,
-            image_grid_thw=model_inputs.image_grid_thw,
-            output_hidden_states=True,
-        )
-
-        hidden_states = outputs.hidden_states[-1]
-        split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
-        split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
-        attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
-        max_seq_len = max([e.size(0) for e in split_hidden_states])
-        prompt_embeds = torch.stack(
-            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
-        )
-        encoder_attention_mask = torch.stack(
-            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
-        )
-
-        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-
-        return prompt_embeds, encoder_attention_mask
-

 def calculate_dimensions(target_area, ratio):
     width = math.sqrt(target_area * ratio)
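
The tail shared by both branches of the new `_get_qwen_prompt_embeds` (and by both removed copies) is a pad-and-stack of variable-length per-prompt embeddings. A minimal, self-contained sketch of just that step, using randomly generated tensors in place of the encoder's hidden states (the sequence lengths and hidden size below are arbitrary illustrative values):

import torch

# Stand-ins for the per-prompt hidden states left after
# _extract_masked_hidden and dropping the template prefix:
# three prompts of lengths 5, 3, 7 with hidden size 8.
split_hidden_states = [torch.randn(n, 8) for n in (5, 3, 7)]

# One all-ones mask per prompt, matching its true (unpadded) length.
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long) for e in split_hidden_states]
max_seq_len = max(e.size(0) for e in split_hidden_states)

# Right-pad each sequence with zeros to the batch maximum, then stack
# into a single (batch, max_seq_len, hidden_dim) tensor. new_zeros
# keeps the dtype and device of the source tensor.
prompt_embeds = torch.stack(
    [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
)
# Pad the masks the same way, so downstream attention ignores the padding.
encoder_attention_mask = torch.stack(
    [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
)

print(prompt_embeds.shape)                 # torch.Size([3, 7, 8])
print(encoder_attention_mask.sum(dim=1))   # tensor([5, 3, 7])

Because this tail is identical in the text-only and multimodal paths, hoisting it into one method on the shared mixin is what lets the two subclasses drop their near-duplicate implementations.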