Skip to content

Commit 77571a8

Browse files
committed
refactor
1 parent 6379241 commit 77571a8

File tree

2 files changed

+34
-19
lines changed

2 files changed

+34
-19
lines changed

scripts/convert_dcae_to_diffusers.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -136,10 +136,10 @@ def get_vae_config(name: str):
136136
elif name in ["dc-ae-f32c32-in-1.0", "dc-ae-f32c32-mix-1.0"]:
137137
config = {
138138
"latent_channels": 32,
139-
"encoder_block_type": ["ResBlock", "ResBlock", "ResBlock", "EViT_GLU", "EViT_GLU", "EViT_GLU"],
139+
"encoder_block_type": ["ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock", "EfficientViTBlock", "EfficientViTBlock"],
140140
"block_out_channels": [128, 256, 512, 512, 1024, 1024],
141141
"encoder_layers_per_block": [0, 4, 8, 2, 2, 2],
142-
"decoder_block_type": ["ResBlock", "ResBlock", "ResBlock", "EViT_GLU", "EViT_GLU", "EViT_GLU"],
142+
"decoder_block_type": ["ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock", "EfficientViTBlock", "EfficientViTBlock"],
143143
"decoder_layers_per_block": [0, 5, 10, 2, 2, 2],
144144
"decoder_norm": ["bn2d", "bn2d", "bn2d", "rms2d", "rms2d", "rms2d"],
145145
"decoder_act": ["relu", "relu", "relu", "silu", "silu", "silu"],
@@ -151,23 +151,23 @@ def get_vae_config(name: str):
151151
"ResBlock",
152152
"ResBlock",
153153
"ResBlock",
154-
"EViT_GLU",
155-
"EViT_GLU",
156-
"EViT_GLU",
157-
"EViT_GLU",
158-
"EViT_GLU",
154+
"EfficientViTBlock",
155+
"EfficientViTBlock",
156+
"EfficientViTBlock",
157+
"EfficientViTBlock",
158+
"EfficientViTBlock",
159159
],
160160
"block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
161161
"encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2, 2],
162162
"decoder_block_type": [
163163
"ResBlock",
164164
"ResBlock",
165165
"ResBlock",
166-
"EViT_GLU",
167-
"EViT_GLU",
168-
"EViT_GLU",
169-
"EViT_GLU",
170-
"EViT_GLU",
166+
"EfficientViTBlock",
167+
"EfficientViTBlock",
168+
"EfficientViTBlock",
169+
"EfficientViTBlock",
170+
"EfficientViTBlock",
171171
],
172172
"decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2, 2],
173173
"decoder_norm": ["bn2d", "bn2d", "bn2d", "rms2d", "rms2d", "rms2d", "rms2d", "rms2d"],
@@ -176,10 +176,10 @@ def get_vae_config(name: str):
176176
elif name in ["dc-ae-f64c128-in-1.0", "dc-ae-f64c128-mix-1.0"]:
177177
config = {
178178
"latent_channels": 128,
179-
"encoder_block_type": ["ResBlock", "ResBlock", "ResBlock", "EViT_GLU", "EViT_GLU", "EViT_GLU", "EViT_GLU"],
179+
"encoder_block_type": ["ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock", "EfficientViTBlock", "EfficientViTBlock", "EfficientViTBlock"],
180180
"block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
181181
"encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2],
182-
"decoder_block_type": ["ResBlock", "ResBlock", "ResBlock", "EViT_GLU", "EViT_GLU", "EViT_GLU", "EViT_GLU"],
182+
"decoder_block_type": ["ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock", "EfficientViTBlock", "EfficientViTBlock", "EfficientViTBlock"],
183183
"decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2],
184184
"decoder_norm": ["bn2d", "bn2d", "bn2d", "rms2d", "rms2d", "rms2d", "rms2d"],
185185
"decoder_act": ["relu", "relu", "relu", "silu", "silu", "silu", "silu"],

src/diffusers/models/autoencoders/autoencoder_dc.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,13 +80,11 @@ def __init__(
8080
in_channels: int,
8181
out_channels: int,
8282
norm_type: str = "bn2d",
83-
act_func=("relu6", None),
83+
act_fn: str = "relu6",
8484
) -> None:
8585
super().__init__()
8686

87-
act_func = val2tuple(act_func, 2)
88-
89-
self.nonlinearity = get_activation(act_func[0]) if act_func[0] is not None else nn.Identity()
87+
self.nonlinearity = get_activation(act_fn) if act_fn is not None else nn.Identity()
9088

9189
self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
9290
self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False)
@@ -259,6 +257,23 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
259257
return x
260258

261259

260+
def get_block_from_block_type(
    block_type: str,
    in_channels: int,
    out_channels: int,
    norm_type: str,
    act_fn: str,
) -> "nn.Module":
    """Build a single autoencoder block selected by its type name.

    Args:
        block_type: Name of the block to build. Supported values are
            ``"ResBlock"`` and ``"EfficientViTBlock"``.
        in_channels: Channel count passed to the block constructor.
        out_channels: Output channel count. Only forwarded to ``ResBlock``;
            ``EfficientViTBlock`` is constructed from ``in_channels`` alone.
        norm_type: Normalization identifier forwarded to the block.
        act_fn: Activation identifier. Only forwarded to ``ResBlock``.

    Returns:
        The constructed block.

    Raises:
        ValueError: If ``block_type`` is not one of the supported names.
    """
    if block_type == "ResBlock":
        block = ResBlock(in_channels, out_channels, norm_type, act_fn)

    elif block_type == "EfficientViTBlock":
        # out_channels and act_fn are intentionally unused here: the block is
        # built from in_channels only.
        block = EfficientViTBlock(in_channels, norm=norm_type, scales=())

    else:
        raise ValueError(f"Block with {block_type=} is not supported.")

    # The original fell off the end here and implicitly returned None, so
    # callers never received the block that was just constructed.
    return block
275+
276+
262277
def build_stage_main(
263278
width: int, depth: int, block_type: str | List[str], norm: str, act: str, input_width: int
264279
) -> list[nn.Module]:
@@ -278,7 +293,7 @@ def build_stage_main(
278293
norm_type=norm,
279294
act_func=(act, None),
280295
)
281-
elif current_block_type == "EViT_GLU":
296+
elif current_block_type == "EfficientViTBlock":
282297
assert in_channels == out_channels
283298
block = EfficientViTBlock(in_channels, norm=norm, scales=())
284299
elif current_block_type == "EViTS5_GLU":

0 commit comments

Comments
 (0)