
Commit 7fa435f

1. remove unused parameters in init;
2. code update;
1 parent f0aa9b9 commit 7fa435f

3 files changed: +28 additions, -9 deletions


scripts/convert_sana_pag_to_diffusers.py

Lines changed: 27 additions & 4 deletions
@@ -7,6 +7,7 @@

 import torch
 from accelerate import init_empty_weights
+from huggingface_hub import hf_hub_download, snapshot_download
 from termcolor import colored
 from transformers import AutoModelForCausalLM, AutoTokenizer

@@ -23,12 +24,35 @@

 CTX = init_empty_weights if is_accelerate_available else nullcontext

-ckpt_id = "Sana"
+ckpt_ids = [
+    "Efficient-Large-Model/Sana_1600M_1024px_MultiLing",
+    "Efficient-Large-Model/Sana_1600M_512px_MultiLing",
+    "Efficient-Large-Model/Sana_1600M_1024px",
+    "Efficient-Large-Model/Sana_1600M_512px",
+    "Efficient-Large-Model/Sana_600M_1024px",
+    "Efficient-Large-Model/Sana_600M_512px",
+]
 # https://github.com/NVlabs/Sana/blob/main/scripts/inference.py


 def main(args):
-    all_state_dict = torch.load(args.orig_ckpt_path, map_location=torch.device("cpu"))
+    ckpt_id = ckpt_ids[0]
+    cache_dir_path = os.path.expanduser("~/.cache/huggingface/hub")
+    if args.orig_ckpt_path is None:
+        snapshot_download(
+            repo_id=ckpt_id,
+            cache_dir=cache_dir_path,
+            repo_type="model",
+        )
+        file_path = hf_hub_download(
+            repo_id=ckpt_id,
+            filename=f"checkpoints/{ckpt_id.split('/')[-1]}.pth",
+            cache_dir=cache_dir_path,
+            repo_type="model",
+        )
+    else:
+        file_path = args.orig_ckpt_path
+    all_state_dict = torch.load(file_path, weights_only=True)
     state_dict = all_state_dict.pop("state_dict")
     converted_state_dict = {}

@@ -143,7 +167,6 @@ def main(args):
         attention_bias=False,
         sample_size=32,
         patch_size=1,
-        activation_fn=("silu", "silu", None),
         upcast_attention=False,
         norm_type="ada_norm_single",
         norm_elementwise_affine=False,
@@ -175,7 +198,7 @@ def main(args):
     print(
         colored(
             f"Only saving transformer model of {args.model_type}. "
-            f"Set --save_full_pipeline to save the whole SanaPipeline",
+            f"Set --save_full_pipeline to save the whole SanaPAGPipeline",
             "green",
             attrs=["bold"],
         )
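For reference, the download-and-resolve logic added above can be exercised on its own. A minimal standalone sketch, assuming the hard-coded repo id mirrors ckpt_ids[0] from the diff; the cache path matches the script's default:

import os

from huggingface_hub import hf_hub_download, snapshot_download

ckpt_id = "Efficient-Large-Model/Sana_1600M_1024px_MultiLing"  # ckpt_ids[0] above
cache_dir_path = os.path.expanduser("~/.cache/huggingface/hub")

# Fetch the whole repo first (sibling files may be needed by the conversion
# script), then resolve the local cache path of the .pth checkpoint inside it.
snapshot_download(repo_id=ckpt_id, cache_dir=cache_dir_path, repo_type="model")
file_path = hf_hub_download(
    repo_id=ckpt_id,
    filename=f"checkpoints/{ckpt_id.split('/')[-1]}.pth",
    cache_dir=cache_dir_path,
    repo_type="model",
)
print(file_path)  # a path under ~/.cache/huggingface/hub

Note that hf_hub_download alone would suffice to fetch the single .pth file; the preceding snapshot_download pulls the rest of the repo as the script does.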

scripts/convert_sana_to_diffusers.py

Lines changed: 1 addition & 2 deletions
@@ -52,7 +52,7 @@ def main(args):
         )
     else:
         file_path = args.orig_ckpt_path
-    all_state_dict = torch.load(file_path, map_location=torch.device("cpu"))
+    all_state_dict = torch.load(file_path, weights_only=True)
     state_dict = all_state_dict.pop("state_dict")
     converted_state_dict = {}

@@ -167,7 +167,6 @@ def main(args):
         attention_bias=False,
         sample_size=32,
         patch_size=1,
-        activation_fn=("silu", "silu", None),
         upcast_attention=False,
         norm_type="ada_norm_single",
         norm_elementwise_affine=False,
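The same torch.load(..., weights_only=True) switch appears in both conversion scripts. A minimal sketch of what it changes, with a hypothetical checkpoint filename:

import torch

# weights_only=True restricts torch.load's unpickler to tensors and primitive
# containers, so loading an untrusted .pth file cannot execute arbitrary code
# during deserialization. Note the previous map_location="cpu" is dropped:
# tensors load onto the device they were saved on, so pass map_location
# explicitly if the checkpoint was saved on a GPU the target machine lacks.
all_state_dict = torch.load("Sana_1600M_1024px.pth", weights_only=True)
state_dict = all_state_dict.pop("state_dict")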

src/diffusers/models/transformers/sana_transformer.py

Lines changed: 0 additions & 3 deletions
@@ -267,8 +267,6 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin):
             The width of the latent images. This parameter is fixed during training.
         patch_size (int, defaults to 1):
             Size of the patches the model processes, relevant for architectures working on non-sequential data.
-        activation_fn (str, optional, defaults to "gelu-approximate"):
-            Activation function to use in feed-forward networks within Transformer blocks.
         num_embeds_ada_norm (int, optional, defaults to 1000):
             Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during
             inference.
@@ -308,7 +306,6 @@ def __init__(
         attention_bias: bool = True,
         sample_size: int = 32,
         patch_size: int = 1,
-        activation_fn: tuple = None,
         num_embeds_ada_norm: Optional[int] = 1000,
         upcast_attention: bool = False,
         norm_type: str = "ada_norm_single",
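Since activation_fn is removed from __init__ (it was unused, per the commit message), callers that still pass it should fail at construction time. A quick sketch, assuming the module path shown in this diff and that the remaining constructor arguments keep their defaults:

from diffusers.models.transformers.sana_transformer import SanaTransformer2DModel

# Constructing with the remaining keyword arguments works as before:
model = SanaTransformer2DModel(sample_size=32, patch_size=1)

# Passing the removed parameter should now raise a TypeError, since the
# signature no longer accepts it:
#   SanaTransformer2DModel(activation_fn=("silu", "silu", None))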
