@@ -193,54 +193,6 @@ def __init__(
193193 self .prompt_template_encode_start_idx = 34
194194 self .default_sample_size = 128
195195
196- # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
197- def _extract_masked_hidden (self , hidden_states : torch .Tensor , mask : torch .Tensor ):
198- bool_mask = mask .bool ()
199- valid_lengths = bool_mask .sum (dim = 1 )
200- selected = hidden_states [bool_mask ]
201- split_result = torch .split (selected , valid_lengths .tolist (), dim = 0 )
202-
203- return split_result
204-
205- # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._get_qwen_prompt_embeds
206- def _get_qwen_prompt_embeds (
207- self ,
208- prompt : Union [str , List [str ]] = None ,
209- device : Optional [torch .device ] = None ,
210- dtype : Optional [torch .dtype ] = None ,
211- ):
212- device = device or self ._execution_device
213- dtype = dtype or self .text_encoder .dtype
214-
215- prompt = [prompt ] if isinstance (prompt , str ) else prompt
216-
217- template = self .prompt_template_encode
218- drop_idx = self .prompt_template_encode_start_idx
219- txt = [template .format (e ) for e in prompt ]
220- txt_tokens = self .tokenizer (
221- txt , max_length = self .tokenizer_max_length + drop_idx , padding = True , truncation = True , return_tensors = "pt"
222- ).to (device )
223- encoder_hidden_states = self .text_encoder (
224- input_ids = txt_tokens .input_ids ,
225- attention_mask = txt_tokens .attention_mask ,
226- output_hidden_states = True ,
227- )
228- hidden_states = encoder_hidden_states .hidden_states [- 1 ]
229- split_hidden_states = self ._extract_masked_hidden (hidden_states , txt_tokens .attention_mask )
230- split_hidden_states = [e [drop_idx :] for e in split_hidden_states ]
231- attn_mask_list = [torch .ones (e .size (0 ), dtype = torch .long , device = e .device ) for e in split_hidden_states ]
232- max_seq_len = max ([e .size (0 ) for e in split_hidden_states ])
233- prompt_embeds = torch .stack (
234- [torch .cat ([u , u .new_zeros (max_seq_len - u .size (0 ), u .size (1 ))]) for u in split_hidden_states ]
235- )
236- encoder_attention_mask = torch .stack (
237- [torch .cat ([u , u .new_zeros (max_seq_len - u .size (0 ))]) for u in attn_mask_list ]
238- )
239-
240- prompt_embeds = prompt_embeds .to (dtype = dtype , device = device )
241-
242- return prompt_embeds , encoder_attention_mask
243-
244196 # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_img2img.QwenImageImg2ImgPipeline._encode_vae_image
245197 def _encode_vae_image (self , image : torch .Tensor , generator : torch .Generator ):
246198 if isinstance (generator , list ):
@@ -277,48 +229,6 @@ def get_timesteps(self, num_inference_steps, strength, device):
277229
278230 return timesteps , num_inference_steps - t_start
279231
280- # Copied fromCopied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt
281- def encode_prompt (
282- self ,
283- prompt : Union [str , List [str ]],
284- device : Optional [torch .device ] = None ,
285- num_images_per_prompt : int = 1 ,
286- prompt_embeds : Optional [torch .Tensor ] = None ,
287- prompt_embeds_mask : Optional [torch .Tensor ] = None ,
288- max_sequence_length : int = 1024 ,
289- ):
290- r"""
291-
292- Args:
293- prompt (`str` or `List[str]`, *optional*):
294- prompt to be encoded
295- device: (`torch.device`):
296- torch device
297- num_images_per_prompt (`int`):
298- number of images that should be generated per prompt
299- prompt_embeds (`torch.Tensor`, *optional*):
300- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
301- provided, text embeddings will be generated from `prompt` input argument.
302- """
303- device = device or self ._execution_device
304-
305- prompt = [prompt ] if isinstance (prompt , str ) else prompt
306- batch_size = len (prompt ) if prompt_embeds is None else prompt_embeds .shape [0 ]
307-
308- if prompt_embeds is None :
309- prompt_embeds , prompt_embeds_mask = self ._get_qwen_prompt_embeds (prompt , device )
310-
311- prompt_embeds = prompt_embeds [:, :max_sequence_length ]
312- prompt_embeds_mask = prompt_embeds_mask [:, :max_sequence_length ]
313-
314- _ , seq_len , _ = prompt_embeds .shape
315- prompt_embeds = prompt_embeds .repeat (1 , num_images_per_prompt , 1 )
316- prompt_embeds = prompt_embeds .view (batch_size * num_images_per_prompt , seq_len , - 1 )
317- prompt_embeds_mask = prompt_embeds_mask .repeat (1 , num_images_per_prompt , 1 )
318- prompt_embeds_mask = prompt_embeds_mask .view (batch_size * num_images_per_prompt , seq_len )
319-
320- return prompt_embeds , prompt_embeds_mask
321-
322232 def check_inputs (
323233 self ,
324234 prompt ,
0 commit comments