Skip to content

Commit 9df6c2f

Browse files
committed
remove more.
1 parent 13cf2b0 commit 9df6c2f

File tree

1 file changed

+0
-90
lines changed

1 file changed

+0
-90
lines changed

src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py

Lines changed: 0 additions & 90 deletions
Original file line number | Diff line number | Diff line change
@@ -193,54 +193,6 @@ def __init__(
193193
self.prompt_template_encode_start_idx = 34
194194
self.default_sample_size = 128
195195

196-
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
197-
def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
198-
bool_mask = mask.bool()
199-
valid_lengths = bool_mask.sum(dim=1)
200-
selected = hidden_states[bool_mask]
201-
split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
202-
203-
return split_result
204-
205-
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._get_qwen_prompt_embeds
206-
def _get_qwen_prompt_embeds(
207-
self,
208-
prompt: Union[str, List[str]] = None,
209-
device: Optional[torch.device] = None,
210-
dtype: Optional[torch.dtype] = None,
211-
):
212-
device = device or self._execution_device
213-
dtype = dtype or self.text_encoder.dtype
214-
215-
prompt = [prompt] if isinstance(prompt, str) else prompt
216-
217-
template = self.prompt_template_encode
218-
drop_idx = self.prompt_template_encode_start_idx
219-
txt = [template.format(e) for e in prompt]
220-
txt_tokens = self.tokenizer(
221-
txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
222-
).to(device)
223-
encoder_hidden_states = self.text_encoder(
224-
input_ids=txt_tokens.input_ids,
225-
attention_mask=txt_tokens.attention_mask,
226-
output_hidden_states=True,
227-
)
228-
hidden_states = encoder_hidden_states.hidden_states[-1]
229-
split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
230-
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
231-
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
232-
max_seq_len = max([e.size(0) for e in split_hidden_states])
233-
prompt_embeds = torch.stack(
234-
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
235-
)
236-
encoder_attention_mask = torch.stack(
237-
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
238-
)
239-
240-
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
241-
242-
return prompt_embeds, encoder_attention_mask
243-
244196
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_img2img.QwenImageImg2ImgPipeline._encode_vae_image
245197
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
246198
if isinstance(generator, list):
@@ -277,48 +229,6 @@ def get_timesteps(self, num_inference_steps, strength, device):
277229

278230
return timesteps, num_inference_steps - t_start
279231

280-
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt
281-
def encode_prompt(
282-
self,
283-
prompt: Union[str, List[str]],
284-
device: Optional[torch.device] = None,
285-
num_images_per_prompt: int = 1,
286-
prompt_embeds: Optional[torch.Tensor] = None,
287-
prompt_embeds_mask: Optional[torch.Tensor] = None,
288-
max_sequence_length: int = 1024,
289-
):
290-
r"""
291-
292-
Args:
293-
prompt (`str` or `List[str]`, *optional*):
294-
prompt to be encoded
295-
device: (`torch.device`):
296-
torch device
297-
num_images_per_prompt (`int`):
298-
number of images that should be generated per prompt
299-
prompt_embeds (`torch.Tensor`, *optional*):
300-
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
301-
provided, text embeddings will be generated from `prompt` input argument.
302-
"""
303-
device = device or self._execution_device
304-
305-
prompt = [prompt] if isinstance(prompt, str) else prompt
306-
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
307-
308-
if prompt_embeds is None:
309-
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
310-
311-
prompt_embeds = prompt_embeds[:, :max_sequence_length]
312-
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
313-
314-
_, seq_len, _ = prompt_embeds.shape
315-
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
316-
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
317-
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
318-
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
319-
320-
return prompt_embeds, prompt_embeds_mask
321-
322232
def check_inputs(
323233
self,
324234
prompt,

0 commit comments

Comments
 (0)