
Commit 47cc046

1. add BF16 .pth checkpoint file path;
2. add complex human instruction to the pipeline;
1 parent 34c5880

2 files changed (+34, -5 lines)

scripts/convert_sana_to_diffusers.py

Lines changed: 5 additions & 3 deletions
@@ -26,6 +26,7 @@
 
 ckpt_ids = [
     "Efficient-Large-Model/Sana_1600M_1024px_MultiLing",
+    "Efficient-Large-Model/Sana_1600M_1024px_BF16",
     "Efficient-Large-Model/Sana_1600M_512px_MultiLing",
     "Efficient-Large-Model/Sana_1600M_1024px",
     "Efficient-Large-Model/Sana_1600M_512px",
@@ -39,7 +40,7 @@ def main(args):
     ckpt_id = ckpt_ids[0]
     cache_dir_path = os.path.expanduser("~/.cache/huggingface/hub")
 
-    if args.orig_ckpt_path is None:
+    if args.orig_ckpt_path is None or args.orig_ckpt_path in ckpt_ids:
         snapshot_download(
             repo_id=ckpt_id,
             cache_dir=cache_dir_path,
@@ -169,7 +170,7 @@ def main(args):
         caption_channels=2304,
         mlp_ratio=2.5,
         attention_bias=False,
-        sample_size=32,
+        sample_size=args.image_size // 32,
         patch_size=1,
         norm_elementwise_affine=False,
         norm_eps=1e-6,
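
The old hardcoded `sample_size=32` only matched the 1024px checkpoints: Sana's DC-AE autoencoder compresses images by a factor of 32 per side, so the transformer's latent grid should be `image_size // 32`. A minimal sketch of the arithmetic (illustration only, not part of the commit):

# Latent grid side per resolution, assuming DC-AE's 32x spatial compression.
for image_size in (512, 1024):
    print(f"{image_size}px -> sample_size={image_size // 32}")
# 512px -> sample_size=16
# 1024px -> sample_size=32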
@@ -191,6 +192,8 @@ def main(args):
     num_model_params = sum(p.numel() for p in transformer.parameters())
     print(f"Total number of transformer parameters: {num_model_params}")
 
+    transformer = transformer.to(weight_dtype)
+
     if not args.save_full_pipeline:
         print(
             colored(
@@ -200,7 +203,6 @@ def main(args):
                 attrs=["bold"],
             )
         )
-        transformer = transformer.to(weight_dtype)
         transformer.save_pretrained(
             os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB", variant=variant
         )
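
Two effects of these script hunks: a Hub repo id listed in `ckpt_ids` can now be passed as `--orig_ckpt_path` and still trigger `snapshot_download`, and the `weight_dtype` cast is hoisted out of the transformer-only branch, so the `--save_full_pipeline` path also saves weights in the requested dtype. A small sketch of what the hoisted cast does (illustration only; a stand-in `nn.Linear` plays the transformer's role):

import torch

weight_dtype = torch.bfloat16
transformer = torch.nn.Linear(4, 4)          # parameters start in fp32
transformer = transformer.to(weight_dtype)   # the cast this commit hoists
print(transformer.weight.dtype)              # torch.bfloat16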

src/diffusers/pipelines/sana/pipeline_sana.py

Lines changed: 29 additions & 2 deletions
@@ -166,6 +166,7 @@ def encode_prompt(
         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         clean_caption: bool = False,
         max_sequence_length: int = 300,
+        complex_human_instruction: list[str] = [],
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -192,6 +193,8 @@ def encode_prompt(
             clean_caption (`bool`, defaults to `False`):
                 If `True`, the function will preprocess and clean the provided caption before encoding.
             max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
+            complex_human_instruction (`list[str]`, defaults to `[]`):
+                If not empty, the function will prepend this complex human instruction to the prompt before encoding.
         """
 
         if device is None:
@@ -206,13 +209,24 @@ def encode_prompt(
 
         # See Section 3.1. of the paper.
         max_length = max_sequence_length
+        select_index = [0] + list(range(-max_length + 1, 0))
 
         if prompt_embeds is None:
             prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
+
+            # prepare complex human instruction
+            if not complex_human_instruction:
+                max_length_all = max_length
+            else:
+                chi_prompt = "\n".join(complex_human_instruction)
+                prompt = [chi_prompt + p for p in prompt]
+                num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
+                max_length_all = num_chi_prompt_tokens + max_length - 2
+
             text_inputs = self.tokenizer(
                 prompt,
                 padding="max_length",
-                max_length=max_length,
+                max_length=max_length_all,
                 truncation=True,
                 add_special_tokens=True,
                 return_tensors="pt",
@@ -223,7 +237,8 @@ def encode_prompt(
             prompt_attention_mask = prompt_attention_mask.to(device)
 
             prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
-            prompt_embeds = prompt_embeds[0]
+            prompt_embeds = prompt_embeds[0][:, select_index]
+            prompt_attention_mask = prompt_attention_mask[:, select_index]
 
         if self.transformer is not None:
             dtype = self.transformer.dtype
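
How the two `encode_prompt` hunks fit together: the tokenizer budget grows to `num_chi_prompt_tokens + max_length - 2` so both the instruction prefix and the user prompt fit, the text encoder attends over the full sequence, and `select_index` then keeps position 0 (the leading special token) plus the last `max_length - 1` positions, trimming the instruction out of the returned embeddings so their length stays fixed at `max_length`. A toy demonstration (not from the commit; `max_length` shrunk to 6 for readability):

import torch

max_length = 6                                     # stands in for 300
select_index = [0] + list(range(-max_length + 1, 0))
print(select_index)                                # [0, -5, -4, -3, -2, -1]

# Toy "embeddings" over 10 positions: token 0, then 4 instruction tokens,
# then 5 prompt tokens (values are just position ids).
seq = torch.arange(10).unsqueeze(0)                # shape (1, 10)
print(seq[:, select_index])                        # tensor([[0, 5, 6, 7, 8, 9]])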
@@ -566,6 +581,16 @@ def __call__(
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 300,
+        complex_human_instruction: list[str] = [
+            'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:',
+            '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.',
+            '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.',
+            'Here are examples of how to transform or refine prompts:',
+            '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.',
+            '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.',
+            'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:',
+            'User Prompt: ',
+        ],
     ) -> Union[SanaPipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.
@@ -644,6 +669,8 @@ def __call__(
                 `._callback_tensor_inputs` attribute of your pipeline class.
             max_sequence_length (`int` defaults to `300`):
                 Maximum sequence length to use with the `prompt`.
+            complex_human_instruction (`list[str]`, *optional*):
+                Instructions for the complex human instruction; see https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55.
 
         Examples:
 
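A hedged usage sketch for the new pipeline argument (the local checkpoint path is hypothetical; any converted Sana checkpoint directory or Hub id would do). By default the prompt-enhancement template above is prepended before text encoding; passing an empty list encodes the raw prompt only:

import torch
from diffusers import SanaPipeline

# Load a converted checkpoint (hypothetical local path from the script above).
pipe = SanaPipeline.from_pretrained("./sana-1600m-1024px-bf16", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Default: the complex human instruction is prepended to the prompt.
image = pipe(prompt="A cat sleeping").images[0]

# Opt out: an empty list disables the instruction prefix.
image_raw = pipe(prompt="A cat sleeping", complex_human_instruction=[]).images[0]
image.save("cat.png")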

0 commit comments
