
Commit b76493f

Change the AE-related code due to the latest update of the DCAE branch.
1 parent: 5687ba1

File tree: 2 files changed (+14, -26 lines)


scripts/convert_sana_pag_to_diffusers.py

Lines changed: 2 additions & 2 deletions
@@ -188,8 +188,8 @@ def main(args):
     print(colored(f"Saving the whole SanaPAGPipeline containing {args.model_type}", "green", attrs=["bold"]))
     # VAE
     ae = AutoencoderDC.from_pretrained(
-        "Efficient-Large-Model/dc_ae_f32c32_sana_1.0_diffusers",
-        torch_dtype=torch.float32,
+        "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers",
+        torch_dtype=torch.bfloat16,
     ).to(device)
 
     # Text Encoder
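
Note: both the checkpoint id and the load dtype of the DC-AE change in this hunk. As a hedged reference, a minimal sketch of loading and round-tripping an image through the autoencoder on its own follows; it assumes the mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers weights are available and follows diffusers' usual .latent / .sample output conventions for encode/decode, which are not part of this diff.

import torch
from diffusers import AutoencoderDC

# Load the DC-AE from its new Hub location in bfloat16, matching the converted pipeline.
device = "cuda" if torch.cuda.is_available() else "cpu"
ae = AutoencoderDC.from_pretrained(
    "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers",
    torch_dtype=torch.bfloat16,
).to(device)

# f32 spatial compression with 32 latent channels: a 256x256 image maps to an 8x8x32 latent.
image = torch.randn(1, 3, 256, 256, dtype=torch.bfloat16, device=device)
with torch.no_grad():
    latent = ae.encode(image).latent           # (1, 32, 8, 8)
    reconstruction = ae.decode(latent).sample  # (1, 3, 256, 256)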

src/diffusers/models/transformers/sana_transformer_2d.py

Lines changed: 12 additions & 24 deletions
@@ -60,30 +60,18 @@ def __init__(self, dim, eps: float, elementwise_affine: bool = True, scale_facto
 
 # Modified from diffusers.models.autoencoders.ecae.GLUMBConv
 @maybe_allow_in_graph
-class SanaGLUMBConv(GLUMBConv):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        kernel_size=3,
-        stride=1,
-        mid_channels=None,
-        expand_ratio=2.5,
-        use_bias=False,
-        norm=(None, None, None),
-        act_func=("silu", "silu", None),
-    ):
-        super().__init__(
-            in_channels=in_channels,
-            out_channels=out_channels,
-            kernel_size=kernel_size,
-            stride=stride,
-            mid_channels=mid_channels,
-            expand_ratio=expand_ratio,
-            use_bias=use_bias,
-            norm=norm,
-            act_func=act_func,
-        )
+class SanaGLUMBConv(nn.Module):
+    def __init__(self, in_channels: int, out_channels: int) -> None:
+        super().__init__()
+
+        hidden_channels = int(2.5 * in_channels)
+
+        self.nonlinearity = nn.SiLU()
+
+        self.conv_inverted = nn.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0)
+        self.conv_depth = nn.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2)
+        self.conv_point = nn.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False)
+        self.norm = RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True)
 
     def forward(self, x: torch.Tensor, HW=None) -> torch.Tensor:
         B, N, C = x.shape
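
The hunk cuts off at the first line of forward, so the rest of the gated forward pass is not visible here. Below is a self-contained sketch of how these layers combine in the GLU-modulated inverted-bottleneck pattern that GLUMBConv uses; the token-to-grid reshape, the gating order, and the stand-in RMSNorm are assumptions for illustration, not the file's actual code.

import torch
import torch.nn as nn


class RMSNormSketch(nn.Module):
    # Stand-in for the RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True)
    # used above, included only so this sketch runs on its own.
    def __init__(self, dim: int, eps: float = 1e-5) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        self.bias = nn.Parameter(torch.zeros(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        variance = x.float().pow(2).mean(dim=-1, keepdim=True)
        x = (x.float() * torch.rsqrt(variance + self.eps)).to(x.dtype)
        return x * self.weight + self.bias


class GLUMBConvSketch(nn.Module):
    # Same layer layout as the rewritten SanaGLUMBConv: 1x1 inverted expansion,
    # 3x3 depthwise conv whose output is split into value and gate halves (GLU),
    # 1x1 pointwise projection, then RMSNorm over the channel dimension.
    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        hidden_channels = int(2.5 * in_channels)
        self.nonlinearity = nn.SiLU()
        self.conv_inverted = nn.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0)
        self.conv_depth = nn.Conv2d(
            hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2
        )
        self.conv_point = nn.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False)
        self.norm = RMSNormSketch(out_channels, eps=1e-5)

    def forward(self, x: torch.Tensor, HW=None) -> torch.Tensor:
        # Tokens arrive as (B, N, C); fold them back into a (B, C, H, W) grid for the convs.
        B, N, C = x.shape
        H, W = HW if HW is not None else (int(N**0.5), int(N**0.5))
        x = x.transpose(1, 2).reshape(B, C, H, W)

        x = self.nonlinearity(self.conv_inverted(x))
        x = self.conv_depth(x)
        value, gate = torch.chunk(x, 2, dim=1)  # GLU: one half modulates the other
        x = self.conv_point(value * self.nonlinearity(gate))

        # Unfold back to tokens and normalize over channels.
        x = x.flatten(2).transpose(1, 2)
        return self.norm(x)


block = GLUMBConvSketch(32, 32)
tokens = torch.randn(2, 256, 32)
print(block(tokens, HW=(16, 16)).shape)  # torch.Size([2, 256, 32])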
