|
7 | 7 |
|
8 | 8 | import torch |
9 | 9 | from accelerate import init_empty_weights |
| 10 | +from huggingface_hub import hf_hub_download, snapshot_download |
10 | 11 | from termcolor import colored |
11 | 12 | from transformers import AutoModelForCausalLM, AutoTokenizer |
12 | 13 |
|
|
23 | 24 |
|
24 | 25 | CTX = init_empty_weights if is_accelerate_available else nullcontext |
25 | 26 |
|
26 | | -ckpt_id = "Sana" |
| 27 | +ckpt_ids = [ |
| 28 | + "Efficient-Large-Model/Sana_1600M_1024px_MultiLing", |
| 29 | + "Efficient-Large-Model/Sana_1600M_512px_MultiLing", |
| 30 | + "Efficient-Large-Model/Sana_1600M_1024px", |
| 31 | + "Efficient-Large-Model/Sana_1600M_512px", |
| 32 | + "Efficient-Large-Model/Sana_600M_1024px", |
| 33 | + "Efficient-Large-Model/Sana_600M_512px", |
| 34 | +] |
27 | 35 | # https://github.com/NVlabs/Sana/blob/main/scripts/inference.py |
28 | 36 |
|
29 | 37 |
|
30 | 38 | def main(args): |
31 | | - all_state_dict = torch.load(args.orig_ckpt_path, map_location=torch.device("cpu")) |
| 39 | + ckpt_id = ckpt_ids[0] |
| 40 | + cache_dir_path = os.path.expanduser("~/.cache/huggingface/hub") |
| 41 | + if args.orig_ckpt_path is None: |
| 42 | + snapshot_download( |
| 43 | + repo_id=ckpt_id, |
| 44 | + cache_dir=cache_dir_path, |
| 45 | + repo_type="model", |
| 46 | + ) |
| 47 | + file_path = hf_hub_download( |
| 48 | + repo_id=ckpt_id, |
| 49 | + filename=f"checkpoints/{ckpt_id.split('/')[-1]}.pth", |
| 50 | + cache_dir=cache_dir_path, |
| 51 | + repo_type="model", |
| 52 | + ) |
| 53 | + else: |
| 54 | + file_path = args.orig_ckpt_path |
| 55 | + all_state_dict = torch.load(file_path, weights_only=True) |
32 | 56 | state_dict = all_state_dict.pop("state_dict") |
33 | 57 | converted_state_dict = {} |
34 | 58 |
|
@@ -143,7 +167,6 @@ def main(args): |
143 | 167 | attention_bias=False, |
144 | 168 | sample_size=32, |
145 | 169 | patch_size=1, |
146 | | - activation_fn=("silu", "silu", None), |
147 | 170 | upcast_attention=False, |
148 | 171 | norm_type="ada_norm_single", |
149 | 172 | norm_elementwise_affine=False, |
@@ -175,7 +198,7 @@ def main(args): |
175 | 198 | print( |
176 | 199 | colored( |
177 | 200 | f"Only saving transformer model of {args.model_type}. " |
178 | | - f"Set --save_full_pipeline to save the whole SanaPipeline", |
| 201 | + f"Set --save_full_pipeline to save the whole SanaPAGPipeline", |
179 | 202 | "green", |
180 | 203 | attrs=["bold"], |
181 | 204 | ) |
|
0 commit comments