
Commit 32032b3

fix code

1 parent 0c6ee77 commit 32032b3

File tree: 2 files changed (+19, -25 lines changed)


src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py

Lines changed: 9 additions & 11 deletions

@@ -291,8 +291,6 @@ def rewire_prompt(self, prompt, device):
         return rewrite_prompt
 
     def _encode_prompt( self, prompt ):
-        prompt = [prompt] if isinstance(prompt, str) else prompt
-        batch_size = len(prompt)
         all_tokens = []
         for clean_prompt_sub, matched in split_quotation(prompt[0]):
             if matched:
@@ -341,23 +339,23 @@ def _encode_prompt( self, prompt ):
         prompt_embeds = text_output.hidden_states[-1].detach()
         prompt_embeds = prompt_embeds[:,self.prompt_template_encode_start_idx: -self.prompt_template_encode_end_idx ,:]
 
-        _, seq_len, _ = prompt_embeds.shape
-
-        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-
         return prompt_embeds
 
     def encode_prompt(self,
                       prompt : List[str] = None,
                       num_images_per_prompt: Optional[int] = 1,
                       prompt_embeds: Optional[torch.Tensor] = None ):
-
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+        batch_size = len(prompt)
         # If prompt_embeds is provided and prompt is None, skip encoding
         if prompt_embeds is None:
-            prompt_embeds = self._encode_prompt( prompt, num_images_per_prompt )
-
+            prompt_embeds = self._encode_prompt( prompt )
+
+        _, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
         text_ids = prepare_pos_ids(modality_id=0,
                                    type='text',
                                    start=(0, 0),
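In this file the repeat/view pair and the batch_size/seq_len bookkeeping move out of _encode_prompt, which per the unchanged context line took only prompt, so the old two-argument call self._encode_prompt( prompt, num_images_per_prompt ) would have raised a TypeError; the duplication now lives in encode_prompt. A minimal sketch of what that duplication idiom does, using dummy shapes rather than the pipeline's real tensors:

import torch

# Minimal sketch of the duplication idiom the commit relocates into
# encode_prompt(); shapes here are dummies, not the pipeline's real tensors.
batch_size, seq_len, dim = 2, 4, 8
num_images_per_prompt = 3
prompt_embeds = torch.randn(batch_size, seq_len, dim)

# Tile along the sequence axis, then fold the copies back into the batch
# axis ("mps friendly" per the source comment, i.e. no repeat_interleave).
out = prompt_embeds.repeat(1, num_images_per_prompt, 1)
out = out.view(batch_size * num_images_per_prompt, seq_len, -1)

assert out.shape == (batch_size * num_images_per_prompt, seq_len, dim)
# Copies of the same prompt end up contiguous along the batch dimension.
assert torch.equal(out[1], prompt_embeds[0])
assert torch.equal(out[3], prompt_embeds[1])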

src/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py

Lines changed: 10 additions & 14 deletions

@@ -279,14 +279,10 @@ def __init__(
         self.default_sample_size = 128
         self.tokenizer_max_length = 512
 
-    def _encode_prompt( self, prompt, image, num_images_per_prompt ):
-
+    def _encode_prompt( self, prompt, image ):
         raw_vl_input = self.image_processor_vl(images=image,return_tensors="pt")
         pixel_values = raw_vl_input['pixel_values']
         image_grid_thw = raw_vl_input['image_grid_thw']
-
-        prompt = [prompt] if isinstance(prompt, str) else prompt
-        batch_size = len(prompt)
         all_tokens = []
         for clean_prompt_sub, matched in split_quotation(prompt[0]):
             if matched:
@@ -348,25 +344,25 @@ def _encode_prompt( self, prompt, image, num_images_per_prompt ):
         prompt_embeds = text_output.hidden_states[-1].detach()
         prompt_embeds = prompt_embeds[:,self.prompt_template_encode_start_idx: -self.prompt_template_encode_end_idx ,:]
 
-        _, seq_len, _ = prompt_embeds.shape
-
-        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-
         return prompt_embeds
 
     @torch.inference_mode()
     def encode_prompt(self,
                       prompt : List[str] = None,
                       image: Optional[torch.Tensor] = None,
                       num_images_per_prompt: Optional[int] = 1,
-                      prompt_embeds: Optional[torch.Tensor] = None,):
-
+                      prompt_embeds: Optional[torch.Tensor] = None):
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+        batch_size = len(prompt)
         # If prompt_embeds is provided and prompt is None, skip encoding
         if prompt_embeds is None:
-            prompt_embeds = self._encode_prompt( prompt, image, num_images_per_prompt )
+            prompt_embeds = self._encode_prompt( prompt, image )
 
+        _, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
         text_ids = prepare_pos_ids(modality_id=0,
                                    type='text',
                                    start=(0, 0),
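The edit pipeline gets the same restructuring, and it also closes a gap: before this commit, the duplication lived inside _encode_prompt, so a caller who passed precomputed prompt_embeds skipped it entirely and never got num_images_per_prompt copies. Duplication now runs in encode_prompt after the prompt_embeds is None branch, covering both paths. A rough standalone sketch of the new control flow; the encode stand-in and the batch_size fallback for the embeds-only path are assumptions for the sketch, not code from the commit:

from typing import List, Optional
import torch

def encode_prompt(prompt: Optional[List[str]] = None,
                  num_images_per_prompt: int = 1,
                  prompt_embeds: Optional[torch.Tensor] = None,
                  encode=lambda p: torch.randn(1, 512, 64)):
    # `encode` stands in for self._encode_prompt(prompt, image).
    prompt = [prompt] if isinstance(prompt, str) else prompt
    # Assumption for the sketch: fall back to the embeds' batch size when no
    # prompt is given (the committed code always uses len(prompt)).
    batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0]
    if prompt_embeds is None:
        prompt_embeds = encode(prompt)
    # Runs on both paths now, so precomputed embeddings are duplicated too.
    _, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    return prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

# Both call styles yield (batch_size * num_images_per_prompt, seq_len, dim):
fresh = encode_prompt(prompt="a red bicycle", num_images_per_prompt=2)
cached = encode_prompt(prompt_embeds=torch.randn(1, 512, 64), num_images_per_prompt=2)
assert fresh.shape == cached.shape == (2, 512, 64)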
