
Commit ad3935f

1. Fix bugs so the convert script runs successfully.
2. Download the checkpoint from the Hub automatically.
1 parent e7c1a59 commit ad3935f

File tree: 5 files changed (+32, -7 lines)

scripts/convert_sana_to_diffusers.py

Lines changed: 26 additions & 3 deletions
@@ -19,16 +19,39 @@
 )
 from diffusers.models.modeling_utils import load_model_dict_into_meta
 from diffusers.utils.import_utils import is_accelerate_available
-
+from huggingface_hub import hf_hub_download, snapshot_download

 CTX = init_empty_weights if is_accelerate_available else nullcontext

-ckpt_id = "Sana"
+ckpt_ids = [
+    "Efficient-Large-Model/Sana_1600M_1024px_MultiLing",
+    "Efficient-Large-Model/Sana_1600M_512px_MultiLing",
+    "Efficient-Large-Model/Sana_1600M_1024px",
+    "Efficient-Large-Model/Sana_1600M_512px",
+    "Efficient-Large-Model/Sana_600M_1024px",
+    "Efficient-Large-Model/Sana_600M_512px",
+]
 # https://github.com/NVlabs/Sana/blob/main/scripts/inference.py


 def main(args):
-    all_state_dict = torch.load(args.orig_ckpt_path, map_location=torch.device("cpu"))
+    ckpt_id = ckpt_ids[0]
+    cache_dir_path = os.path.expanduser("~/.cache/huggingface/hub")
+    if args.orig_ckpt_path is None:
+        snapshot_download(
+            repo_id=ckpt_id,
+            cache_dir=cache_dir_path,
+            repo_type="model",
+        )
+        file_path = hf_hub_download(
+            repo_id=ckpt_id,
+            filename=f"checkpoints/{ckpt_id.split('/')[-1]}.pth",
+            cache_dir=cache_dir_path,
+            repo_type="model",
+        )
+    else:
+        file_path = args.orig_ckpt_path
+    all_state_dict = torch.load(file_path, map_location=torch.device("cpu"))
     state_dict = all_state_dict.pop("state_dict")
     converted_state_dict = {}
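Note on the hunk above: when --orig_ckpt_path is not supplied, the script now resolves the checkpoint from the Hub itself. Below is a standalone sketch of that fallback, under the repo and checkpoints/<name>.pth layout shown in the diff; the helper name resolve_checkpoint is hypothetical, not part of the script.

```python
import os
from typing import Optional

from huggingface_hub import hf_hub_download, snapshot_download


def resolve_checkpoint(orig_ckpt_path: Optional[str], ckpt_id: str) -> str:
    """Return a local .pth path, downloading from the Hub when none is given."""
    if orig_ckpt_path is not None:
        return orig_ckpt_path

    cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
    # Mirror the script: pull the full repo snapshot into the local cache first.
    snapshot_download(repo_id=ckpt_id, cache_dir=cache_dir, repo_type="model")
    # Then resolve the concrete checkpoint file inside it, e.g.
    # checkpoints/Sana_1600M_1024px.pth for Efficient-Large-Model/Sana_1600M_1024px.
    return hf_hub_download(
        repo_id=ckpt_id,
        filename=f"checkpoints/{ckpt_id.split('/')[-1]}.pth",
        cache_dir=cache_dir,
        repo_type="model",
    )


if __name__ == "__main__":
    print(resolve_checkpoint(None, "Efficient-Large-Model/Sana_1600M_1024px"))
```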

src/diffusers/models/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@
     _import_structure["transformers.latte_transformer_3d"] = ["LatteTransformer3DModel"]
     _import_structure["transformers.lumina_nextdit2d"] = ["LuminaNextDiT2DModel"]
     _import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
-    _import_structure["transformers.sana_transformer_2d"] = ["SanaTransformer2DModel"]
+    _import_structure["transformers.sana_transformer"] = ["SanaTransformer2DModel"]
     _import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
     _import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"]
     _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
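The key rename matters because diffusers' lazy loader treats each _import_structure key as a real submodule path, and the transformer now lives in transformers/sana_transformer.py (see the last file in this commit). A minimal sketch of that resolution with plain importlib rather than diffusers' internal loader; the lazy_get helper is illustrative only.

```python
# Sketch: a lazy-import key must name a module that importlib can actually
# resolve; the attribute list says what to re-export from it.
import importlib

_import_structure = {
    "transformers.sana_transformer": ["SanaTransformer2DModel"],
}


def lazy_get(package: str, dotted_module: str, attr: str):
    # Resolve e.g. diffusers.models.transformers.sana_transformer lazily,
    # then pull the requested class off the loaded module.
    module = importlib.import_module(f"{package}.{dotted_module}")
    return getattr(module, attr)


# With the old key "transformers.sana_transformer_2d" this call would raise
# ModuleNotFoundError, because the file on disk is sana_transformer.py.
cls = lazy_get("diffusers.models", "transformers.sana_transformer", "SanaTransformer2DModel")
print(cls.__name__)
```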

src/diffusers/models/autoencoders/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -8,5 +8,5 @@
 from .autoencoder_oobleck import AutoencoderOobleck
 from .autoencoder_tiny import AutoencoderTiny
 from .consistency_decoder_vae import ConsistencyDecoderVAE
-from .autoencoder_dc import DCAE
+from .autoencoder_dc import AutoencoderDC
 from .vq_model import VQModel

src/diffusers/models/normalization.py

Lines changed: 1 addition & 1 deletion
@@ -594,7 +594,7 @@ def get_normalization(

 class RMSNormScaled(nn.Module):
     def __init__(self, dim, eps: float, elementwise_affine: bool = True, scale_factor: float = 1.0, bias: bool = False):
-        super().__init__(dim, eps, elementwise_affine)
+        super().__init__()
         self.weight = nn.Parameter(torch.ones(dim) * scale_factor)

         self.eps = eps
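The one-line fix above is needed because RMSNormScaled subclasses nn.Module directly, and nn.Module.__init__ takes no arguments, so super().__init__(dim, eps, elementwise_affine) failed at construction time. A minimal sketch of the class as the diff leaves it; the forward pass is an assumption based on standard RMSNorm and is not part of this commit.

```python
import torch
from torch import nn


class RMSNormScaled(nn.Module):
    def __init__(self, dim, eps: float, elementwise_affine: bool = True,
                 scale_factor: float = 1.0, bias: bool = False):
        super().__init__()  # nn.Module's initializer is parameterless
        self.weight = nn.Parameter(torch.ones(dim) * scale_factor)
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Assumed standard RMS normalization followed by the learned, pre-scaled gain.
        variance = hidden_states.float().pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps).to(hidden_states.dtype)
        return hidden_states * self.weight


# Quick check: output keeps the input shape.
print(RMSNormScaled(64, eps=1e-6)(torch.randn(2, 8, 64)).shape)
```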

src/diffusers/models/transformers/sana_transformer.py

Lines changed: 3 additions & 1 deletion
@@ -248,7 +248,6 @@ def forward(


 class SanaTransformer2DModel(ModelMixin, ConfigMixin):
-    # TODO: Change pixart name below
     r"""
     A 2D Transformer model as introduced in Sana family of models (https://arxiv.org/abs/2310.00426,
     https://arxiv.org/abs/2403.04692).
@@ -272,6 +271,8 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin):
             The width of the latent images. This parameter is fixed during training.
         patch_size (int, defaults to 1):
             Size of the patches the model processes, relevant for architectures working on non-sequential data.
+        activation_fn (str, optional, defaults to "gelu-approximate"):
+            Activation function to use in feed-forward networks within Transformer blocks.
         num_embeds_ada_norm (int, optional, defaults to 1000):
             Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during
             inference.
@@ -311,6 +312,7 @@ def __init__(
         attention_bias: bool = True,
         sample_size: int = 32,
         patch_size: int = 1,
+        activation_fn: tuple = None,
         num_embeds_ada_norm: Optional[int] = 1000,
         upcast_attention: bool = False,
         norm_type: str = "ada_norm_single",
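The new docstring entry describes activation_fn as a string defaulting to "gelu-approximate", while the added signature line annotates it as tuple = None; only the parameter is visible here, not how it is consumed. As a hedged illustration, this is how such a string is typically fed to diffusers' FeedForward block; the dim value is just an example, and the exact wiring inside SanaTransformer2DModel is an assumption.

```python
import torch
from diffusers.models.attention import FeedForward

# "gelu-approximate" is the default named in the new docstring entry.
ff = FeedForward(dim=1152, mult=4, activation_fn="gelu-approximate")

hidden_states = torch.randn(1, 16, 1152)
print(ff(hidden_states).shape)  # torch.Size([1, 16, 1152])
```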
