Skip to content

Commit 856231c

Browse files
committed
fix tests
1 parent fce2f9e commit 856231c

File tree

4 files changed

+37
-21
lines changed

4 files changed

+37
-21
lines changed

src/diffusers/pipelines/easyanimate/pipeline_easyanimate.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -240,10 +240,10 @@ def __init__(
240240

241241
def encode_prompt(
242242
self,
243-
prompt: str,
243+
prompt: Union[str, List[str]],
244244
num_images_per_prompt: int = 1,
245245
do_classifier_free_guidance: bool = True,
246-
negative_prompt: Optional[str] = None,
246+
negative_prompt: Optional[Union[str, List[str]]] = None,
247247
prompt_embeds: Optional[torch.Tensor] = None,
248248
negative_prompt_embeds: Optional[torch.Tensor] = None,
249249
prompt_attention_mask: Optional[torch.Tensor] = None,
@@ -294,7 +294,7 @@ def encode_prompt(
294294
batch_size = prompt_embeds.shape[0]
295295

296296
if prompt_embeds is None:
297-
if prompt is not None and isinstance(prompt, str):
297+
if isinstance(prompt, str):
298298
messages = [
299299
{
300300
"role": "user",
@@ -309,10 +309,12 @@ def encode_prompt(
309309
}
310310
for _prompt in prompt
311311
]
312-
text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
312+
text = [
313+
self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages
314+
]
313315

314316
text_inputs = self.tokenizer(
315-
text=[text],
317+
text=text,
316318
padding="max_length",
317319
max_length=max_sequence_length,
318320
truncation=True,
@@ -358,10 +360,12 @@ def encode_prompt(
358360
}
359361
for _negative_prompt in negative_prompt
360362
]
361-
text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
363+
text = [
364+
self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages
365+
]
362366

363367
text_inputs = self.tokenizer(
364-
text=[text],
368+
text=text,
365369
padding="max_length",
366370
max_length=max_sequence_length,
367371
truncation=True,

src/diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -351,10 +351,10 @@ def __init__(
351351
# Copied from diffusers.pipelines.easyanimate.pipeline_easyanimate.EasyAnimatePipeline.encode_prompt
352352
def encode_prompt(
353353
self,
354-
prompt: str,
354+
prompt: Union[str, List[str]],
355355
num_images_per_prompt: int = 1,
356356
do_classifier_free_guidance: bool = True,
357-
negative_prompt: Optional[str] = None,
357+
negative_prompt: Optional[Union[str, List[str]]] = None,
358358
prompt_embeds: Optional[torch.Tensor] = None,
359359
negative_prompt_embeds: Optional[torch.Tensor] = None,
360360
prompt_attention_mask: Optional[torch.Tensor] = None,
@@ -405,7 +405,7 @@ def encode_prompt(
405405
batch_size = prompt_embeds.shape[0]
406406

407407
if prompt_embeds is None:
408-
if prompt is not None and isinstance(prompt, str):
408+
if isinstance(prompt, str):
409409
messages = [
410410
{
411411
"role": "user",
@@ -420,10 +420,12 @@ def encode_prompt(
420420
}
421421
for _prompt in prompt
422422
]
423-
text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
423+
text = [
424+
self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages
425+
]
424426

425427
text_inputs = self.tokenizer(
426-
text=[text],
428+
text=text,
427429
padding="max_length",
428430
max_length=max_sequence_length,
429431
truncation=True,
@@ -469,10 +471,12 @@ def encode_prompt(
469471
}
470472
for _negative_prompt in negative_prompt
471473
]
472-
text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
474+
text = [
475+
self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages
476+
]
473477

474478
text_inputs = self.tokenizer(
475-
text=[text],
479+
text=text,
476480
padding="max_length",
477481
max_length=max_sequence_length,
478482
truncation=True,

src/diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -393,10 +393,10 @@ def __init__(
393393
# Copied from diffusers.pipelines.easyanimate.pipeline_easyanimate.EasyAnimatePipeline.encode_prompt
394394
def encode_prompt(
395395
self,
396-
prompt: str,
396+
prompt: Union[str, List[str]],
397397
num_images_per_prompt: int = 1,
398398
do_classifier_free_guidance: bool = True,
399-
negative_prompt: Optional[str] = None,
399+
negative_prompt: Optional[Union[str, List[str]]] = None,
400400
prompt_embeds: Optional[torch.Tensor] = None,
401401
negative_prompt_embeds: Optional[torch.Tensor] = None,
402402
prompt_attention_mask: Optional[torch.Tensor] = None,
@@ -447,7 +447,7 @@ def encode_prompt(
447447
batch_size = prompt_embeds.shape[0]
448448

449449
if prompt_embeds is None:
450-
if prompt is not None and isinstance(prompt, str):
450+
if isinstance(prompt, str):
451451
messages = [
452452
{
453453
"role": "user",
@@ -462,10 +462,12 @@ def encode_prompt(
462462
}
463463
for _prompt in prompt
464464
]
465-
text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
465+
text = [
466+
self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages
467+
]
466468

467469
text_inputs = self.tokenizer(
468-
text=[text],
470+
text=text,
469471
padding="max_length",
470472
max_length=max_sequence_length,
471473
truncation=True,
@@ -511,10 +513,12 @@ def encode_prompt(
511513
}
512514
for _negative_prompt in negative_prompt
513515
]
514-
text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
516+
text = [
517+
self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages
518+
]
515519

516520
text_inputs = self.tokenizer(
517-
text=[text],
521+
text=text,
518522
padding="max_length",
519523
max_length=max_sequence_length,
520524
truncation=True,

tests/pipelines/easyanimate/test_easyanimate.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,10 @@ def test_dict_tuple_outputs_equivalent(self, expected_slice=None, expected_max_d
250250
# Seems to need a higher tolerance
251251
return super().test_dict_tuple_outputs_equivalent(expected_slice, expected_max_difference)
252252

253+
def test_encode_prompt_works_in_isolation(self):
254+
# Seems to need a higher tolerance
255+
return super().test_encode_prompt_works_in_isolation(atol=1e-3, rtol=1e-3)
256+
253257

254258
@slow
255259
@require_torch_gpu

0 commit comments

Comments (0)