|
45 | 45 | >>> # pipe.transformer.set_attention_backend("flash") |
46 | 46 | >>> # (2) Use flash attention 3 |
47 | 47 | >>> # pipe.transformer.set_attention_backend("_flash_3") |
48 | | - |
49 | | - >>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化:一辆复古蒸汽小火车化身为巨大的拉链头,正拉开厚厚的冬日积雪,展露出一个生机盎然的春天。" |
| 48 | + >>> prompt = '一幅为名为"造相「Z-IMAGE-TURBO」"的项目设计的创意海报。画面巧妙地将文字概念视觉化:一辆复古蒸汽小火车化身为巨大的拉链头,正拉开厚厚的冬日积雪,展露出一个生机盎然的春天。' |
50 | 49 | >>> image = pipe( |
51 | 50 | ... prompt, |
52 | 51 | ... height=1024, |
@@ -432,26 +431,35 @@ def __call__( |
432 | 431 | elif prompt is not None and isinstance(prompt, list): |
433 | 432 | batch_size = len(prompt) |
434 | 433 | else: |
435 | | - batch_size = prompt_embeds.shape[0] |
| 434 | + batch_size = len(prompt_embeds) |
436 | 435 | |
437 | 436 | lora_scale = ( |
438 | 437 | self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None |
439 | 438 | ) |
440 | | - ( |
441 | | - prompt_embeds, |
442 | | - negative_prompt_embeds, |
443 | | - ) = self.encode_prompt( |
444 | | - prompt=prompt, |
445 | | - negative_prompt=negative_prompt, |
446 | | - do_classifier_free_guidance=self.do_classifier_free_guidance, |
447 | | - prompt_embeds=prompt_embeds, |
448 | | - negative_prompt_embeds=negative_prompt_embeds, |
449 | | - dtype=dtype, |
450 | | - device=device, |
451 | | - num_images_per_prompt=num_images_per_prompt, |
452 | | - max_sequence_length=max_sequence_length, |
453 | | - lora_scale=lora_scale, |
454 | | - ) |
| 439 | + |
| 440 | + # If prompt_embeds is provided and prompt is None, skip encoding |
| 441 | + if prompt_embeds is not None and prompt is None: |
| 442 | + if self.do_classifier_free_guidance and negative_prompt_embeds is None: |
| 443 | + raise ValueError( |
| 444 | + "When `prompt_embeds` is provided without `prompt`, " |
| 445 | + "`negative_prompt_embeds` must also be provided for classifier-free guidance." |
| 446 | + ) |
| 447 | + else: |
| 448 | + ( |
| 449 | + prompt_embeds, |
| 450 | + negative_prompt_embeds, |
| 451 | + ) = self.encode_prompt( |
| 452 | + prompt=prompt, |
| 453 | + negative_prompt=negative_prompt, |
| 454 | + do_classifier_free_guidance=self.do_classifier_free_guidance, |
| 455 | + prompt_embeds=prompt_embeds, |
| 456 | + negative_prompt_embeds=negative_prompt_embeds, |
| 457 | + dtype=dtype, |
| 458 | + device=device, |
| 459 | + num_images_per_prompt=num_images_per_prompt, |
| 460 | + max_sequence_length=max_sequence_length, |
| 461 | + lora_scale=lora_scale, |
| 462 | + ) |
455 | 463 | |
456 | 464 | # 4. Prepare latent variables |
457 | 465 | num_channels_latents = self.transformer.in_channels |
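
For reference, a minimal usage sketch of the code path this hunk enables: passing precomputed `prompt_embeds` (plus `negative_prompt_embeds` when classifier-free guidance is active) with `prompt=None`, so `encode_prompt` is skipped entirely. The checkpoint id, the embedding files, and the extra call arguments (`width`, `.images[0]`) are assumptions based on the usual diffusers pipeline API, not taken from this diff.

```python
import torch
from diffusers import DiffusionPipeline

# Placeholder checkpoint id; any compatible Z-Image-Turbo checkpoint would work the same way.
pipe = DiffusionPipeline.from_pretrained("<z-image-turbo-checkpoint>", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Assume the embeddings were computed and cached earlier (e.g. via the pipeline's own
# encode_prompt) and are loaded back here as tensors.
prompt_embeds = torch.load("prompt_embeds.pt").to("cuda", dtype=torch.bfloat16)
negative_prompt_embeds = torch.load("negative_prompt_embeds.pt").to("cuda", dtype=torch.bfloat16)

# With this change the pipeline does not re-encode: prompt stays None and the precomputed
# embeddings are used directly. negative_prompt_embeds is required whenever classifier-free
# guidance is active; omitting it then raises the new ValueError added in this hunk.
image = pipe(
    prompt=None,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    height=1024,
    width=1024,
).images[0]
```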