Commit 7c23fcc

fix the complex_human_instruct bug and typo;
1 parent f413407 commit 7c23fcc

1 file changed: +10 -9 lines

src/diffusers/pipelines/sana/pipeline_sana.py (10 additions, 9 deletions)
```diff
@@ -165,7 +165,7 @@ def encode_prompt(
         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         clean_caption: bool = False,
         max_sequence_length: int = 300,
-        complex_huamen_instruction: list[str] = [],
+        complex_human_instruction=None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
```
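Besides fixing the "huamen" typo, the new signature drops the mutable `[]` default in favour of `None`, which also sidesteps Python's shared-mutable-default pitfall. A minimal sketch of that pattern (illustration only, not part of the commit; the function and names below are made up):

```python
# Sketch: a None default plus an in-function check is the safer pattern for
# list-valued arguments. A `= []` default is created once at definition time
# and shared across calls; `None` avoids that and is what encode_prompt checks.
def encode(prompt: str, instruction: list[str] | None = None) -> str:
    # Treat None (or an empty list) as "no extra instruction".
    if not instruction:
        return prompt
    return "\n".join(instruction) + prompt

print(encode("a red fox"))                        # "a red fox"
print(encode("a red fox", ["Enhance detail:\n"])) # instruction prepended
```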
```diff
@@ -187,13 +187,13 @@ def encode_prompt(
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
             negative_prompt_embeds (`torch.Tensor`, *optional*):
-                Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the ""
+                Pre-generated negative text embeddings. For Sana, it's should be the embeddings of the ""
                 string.
             clean_caption (`bool`, defaults to `False`):
                 If `True`, the function will preprocess and clean the provided caption before encoding.
             max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
-            use_complex_huamen_instruction (`list[str]`, defaults to `complex_huamen_instruction`):
-                If `complex_huamen_instruction` is not empty, the function will use the complex Huamen instruction for
+            complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
+                If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
                 the prompt.
         """
 
@@ -214,11 +214,11 @@ def encode_prompt(
         if prompt_embeds is None:
             prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
 
-            # prepare complex huamen instruction
-            if not complex_huamen_instruction:
+            # prepare complex human instruction
+            if not complex_human_instruction:
                 max_length_all = max_length
             else:
-                chi_prompt = "\n".join(complex_huamen_instruction)
+                chi_prompt = "\n".join(complex_human_instruction)
                 prompt = [chi_prompt + p for p in prompt]
                 num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
                 max_length_all = num_chi_prompt_tokens + max_length - 2
```
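The second hunk above is where the instruction block is folded into the prompt: the lines are joined, prepended to every prompt, and the token budget is widened by the prefix length. A standalone sketch of that logic (illustration only, not the pipeline code; the toy tokenizer and instruction strings are placeholders standing in for the pipeline's own `self.tokenizer` and instruction list):

```python
# Sketch of the prompt-prefixing logic from the hunk above. Only the token
# count returned by .encode() matters here, so a whitespace tokenizer suffices.
class ToyTokenizer:
    def encode(self, text: str) -> list[str]:
        return text.split()

tokenizer = ToyTokenizer()
max_length = 300                      # same default as `max_sequence_length`
complex_human_instruction = [         # placeholder instruction lines
    "Given a user prompt, rewrite it with richer visual detail.",
    "- Keep the rewritten prompt concise.",
]
prompt = ["a cat sitting on a windowsill"]

if not complex_human_instruction:     # covers both None and an empty list
    max_length_all = max_length
else:
    chi_prompt = "\n".join(complex_human_instruction)
    prompt = [chi_prompt + p for p in prompt]                  # prefix every prompt
    num_chi_prompt_tokens = len(tokenizer.encode(chi_prompt))
    max_length_all = num_chi_prompt_tokens + max_length - 2    # widen budget for the prefix

print(prompt[0][:40], max_length_all)
```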
```diff
@@ -581,7 +581,7 @@ def __call__(
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 300,
-        complex_human_attention: list[str] = [
+        complex_human_instruction: list[str] = [
             'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:',
             "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.",
             "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.",
@@ -669,7 +669,7 @@ def __call__(
                 `._callback_tensor_inputs` attribute of your pipeline class.
             max_sequence_length (`int` defaults to `300`):
                 Maximum sequence length to use with the `prompt`.
-            complex_human_attention (`list[str]`, *optional*):
+            complex_human_instruction (`list[str]`, *optional*):
                 Instructions for complex human attention:
                 https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55.
 
@@ -740,6 +740,7 @@ def __call__(
             negative_prompt_attention_mask=negative_prompt_attention_mask,
             clean_caption=clean_caption,
             max_sequence_length=max_sequence_length,
+            complex_human_instruction=complex_human_instruction,
         )
         if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
```
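With the last hunk, whatever is passed as `complex_human_instruction` to `__call__` is now actually forwarded to `encode_prompt` instead of being silently dropped. A hedged usage sketch under that assumption (the checkpoint id and prompt below are illustrative, not taken from this commit; the parameter names match the signatures in this diff):

```python
# Usage sketch: pass a custom instruction list, or None to disable the
# prompt-enhancement prefix entirely. Checkpoint id is an assumption.
import torch
from diffusers import SanaPipeline

pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers",  # assumed checkpoint id
    torch_dtype=torch.bfloat16,
).to("cuda")

image = pipe(
    prompt="a cyberpunk cat with a neon sign that says 'Sana'",
    complex_human_instruction=None,  # or a list[str] with your own instruction block
).images[0]
```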
