Skip to content

Commit 1f8a3b3

Browse files
committed
update
1 parent c1c02a2 commit 1f8a3b3

File tree

2 files changed

+48
-117
lines changed

2 files changed

+48
-117
lines changed

scripts/convert_dcae_to_diffusers.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,18 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]):
2020

2121
VAE_KEYS_RENAME_DICT = {
2222
# common
23+
"main.": "",
2324
"op_list.": "",
2425
"norm.": "norm.norm.",
26+
"depth_conv": "conv_depth",
27+
"point_conv": "conv_point",
28+
"inverted_conv": "conv_inverted",
29+
"conv.conv.": "conv.",
2530
# encoder
2631
"encoder.project_in.conv": "encoder.conv_in",
27-
"encoder.project_out.main.0.conv": "encoder.conv_out",
32+
"encoder.project_out.0.conv": "encoder.conv_out",
2833
# decoder
29-
"decoder.project_in.main.conv": "decoder.conv_in",
34+
"decoder.project_in.conv": "decoder.conv_in",
3035
"decoder.project_out.0": "decoder.norm_out.norm",
3136
"decoder.project_out.2.conv": "decoder.conv_out",
3237
}

src/diffusers/models/autoencoders/autoencoder_dc.py

Lines changed: 41 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -111,59 +111,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
111111
return x
112112

113113

114-
class UpsamplePixelShuffle(nn.Module):
115-
def __init__(
116-
self,
117-
in_channels: int,
118-
out_channels: int,
119-
kernel_size: int,
120-
factor: int,
121-
):
122-
super().__init__()
123-
self.factor = factor
124-
out_ratio = factor**2
125-
self.conv = DCConv2d(
126-
in_channels=in_channels,
127-
out_channels=out_channels * out_ratio,
128-
kernel_size=kernel_size,
129-
use_bias=True,
130-
norm=None,
131-
act_func=None,
132-
)
133-
134-
def forward(self, x: torch.Tensor) -> torch.Tensor:
135-
x = self.conv(x)
136-
x = F.pixel_shuffle(x, self.factor)
137-
return x
138-
139-
140-
class UpsampleInterpolate(nn.Module):
141-
def __init__(
142-
self,
143-
in_channels: int,
144-
out_channels: int,
145-
kernel_size: int,
146-
factor: int,
147-
mode: str = "nearest",
148-
) -> None:
149-
super().__init__()
150-
self.factor = factor
151-
self.mode = mode
152-
self.conv = DCConv2d(
153-
in_channels=in_channels,
154-
out_channels=out_channels,
155-
kernel_size=kernel_size,
156-
use_bias=True,
157-
norm=None,
158-
act_func=None,
159-
)
160-
161-
def forward(self, x: torch.Tensor) -> torch.Tensor:
162-
x = torch.nn.functional.interpolate(x, scale_factor=self.factor, mode=self.mode)
163-
x = self.conv(x)
164-
return x
165-
166-
167114
class UpsampleChannelDuplicatingPixelUnshuffle(nn.Module):
168115
def __init__(
169116
self,
@@ -184,11 +131,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
184131
return x
185132

186133

187-
class IdentityLayer(nn.Module):
188-
def forward(self, x: torch.Tensor) -> torch.Tensor:
189-
return x
190-
191-
192134
class GLUMBConv(nn.Module):
193135
def __init__(
194136
self,
@@ -210,15 +152,15 @@ def __init__(
210152
mid_channels = round(in_channels * expand_ratio) if mid_channels is None else mid_channels
211153

212154
self.glu_act = get_activation(act_func[1])
213-
self.inverted_conv = DCConv2d(
155+
self.conv_inverted = DCConv2d(
214156
in_channels,
215157
mid_channels * 2,
216158
1,
217159
use_bias=use_bias[0],
218160
norm=norm[0],
219161
act_func=act_func[0],
220162
)
221-
self.depth_conv = DCConv2d(
163+
self.conv_depth = DCConv2d(
222164
mid_channels * 2,
223165
mid_channels * 2,
224166
kernel_size,
@@ -228,7 +170,7 @@ def __init__(
228170
norm=norm[1],
229171
act_func=None,
230172
)
231-
self.point_conv = DCConv2d(
173+
self.conv_point = DCConv2d(
232174
mid_channels,
233175
out_channels,
234176
1,
@@ -238,15 +180,16 @@ def __init__(
238180
)
239181

240182
def forward(self, x: torch.Tensor) -> torch.Tensor:
241-
x = self.inverted_conv(x)
242-
x = self.depth_conv(x)
183+
residual = x
184+
x = self.conv_inverted(x)
185+
x = self.conv_depth(x)
243186

244187
x, gate = torch.chunk(x, 2, dim=1)
245188
gate = self.glu_act(gate)
246189
x = x * gate
247190

248-
x = self.point_conv(x)
249-
return x
191+
x = self.conv_point(x)
192+
return x + residual
250193

251194

252195
class ResBlock(nn.Module):
@@ -289,9 +232,10 @@ def __init__(
289232
)
290233

291234
def forward(self, x: torch.Tensor) -> torch.Tensor:
235+
residual = x
292236
x = self.conv1(x)
293237
x = self.conv2(x)
294-
return x
238+
return x + residual
295239

296240

297241
class LiteMLA(nn.Module):
@@ -357,7 +301,6 @@ def __init__(
357301
act_func=act_func[1],
358302
)
359303

360-
@torch.autocast(device_type="cuda", enabled=False)
361304
def relu_linear_att(self, qkv: torch.Tensor) -> torch.Tensor:
362305
B, _, H, W = list(qkv.size())
363306

@@ -429,6 +372,7 @@ def relu_quadratic_att(self, qkv: torch.Tensor) -> torch.Tensor:
429372
return out
430373

431374
def forward(self, x: torch.Tensor) -> torch.Tensor:
375+
residual = x
432376
# generate multi-scale q, k, v
433377
qkv = self.qkv(x)
434378
multi_scale_qkv = [qkv]
@@ -443,7 +387,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
443387
out = self.relu_quadratic_att(qkv)
444388
out = self.proj(out)
445389

446-
return out
390+
return out + residual
447391

448392

449393
class EfficientViTBlock(nn.Module):
@@ -461,30 +405,24 @@ def __init__(
461405
):
462406
super().__init__()
463407
if context_module == "LiteMLA":
464-
self.context_module = ResidualBlock(
465-
LiteMLA(
466-
in_channels=in_channels,
467-
out_channels=in_channels,
468-
heads_ratio=heads_ratio,
469-
dim=dim,
470-
norm=(None, norm),
471-
scales=scales,
472-
),
473-
IdentityLayer(),
408+
self.context_module = LiteMLA(
409+
in_channels=in_channels,
410+
out_channels=in_channels,
411+
heads_ratio=heads_ratio,
412+
dim=dim,
413+
norm=(None, norm),
414+
scales=scales,
474415
)
475416
else:
476417
raise ValueError(f"context_module {context_module} is not supported")
477418
if local_module == "GLUMBConv":
478-
self.local_module = ResidualBlock(
479-
GLUMBConv(
480-
in_channels=in_channels,
481-
out_channels=in_channels,
482-
expand_ratio=expand_ratio,
483-
use_bias=(True, True, False),
484-
norm=(None, None, norm),
485-
act_func=(act_func, act_func, None),
486-
),
487-
IdentityLayer(),
419+
self.local_module = GLUMBConv(
420+
in_channels=in_channels,
421+
out_channels=in_channels,
422+
expand_ratio=expand_ratio,
423+
use_bias=(True, True, False),
424+
norm=(None, None, norm),
425+
act_func=(act_func, act_func, None),
488426
)
489427
else:
490428
raise NotImplementedError(f"local_module {local_module} is not supported")
@@ -546,7 +484,7 @@ def build_stage_main(
546484

547485
if current_block_type == "ResBlock":
548486
assert in_channels == out_channels
549-
main_block = ResBlock(
487+
block = ResBlock(
550488
in_channels=in_channels,
551489
out_channels=out_channels,
552490
kernel_size=3,
@@ -555,7 +493,6 @@ def build_stage_main(
555493
norm=(None, norm),
556494
act_func=(act, None),
557495
)
558-
block = ResidualBlock(main_block, IdentityLayer())
559496
elif current_block_type == "EViT_GLU":
560497
assert in_channels == out_channels
561498
block = EfficientViTBlock(in_channels, norm=norm, act_func=act, local_module="GLUMBConv", scales=())
@@ -619,7 +556,7 @@ def __init__(
619556
self.conv = nn.Conv2d(
620557
in_channels,
621558
out_channels,
622-
kernel_size=(kernel_size, kernel_size),
559+
kernel_size=kernel_size,
623560
stride=self.stride,
624561
padding=kernel_size // 2,
625562
)
@@ -654,32 +591,21 @@ def __init__(
654591
super().__init__()
655592

656593
self.interpolate = interpolate
657-
self.interpolation_method = interpolation_mode
594+
self.interpolation_mode = interpolation_mode
658595
self.factor = 2
659596
self.stride = 1
660597

661598
out_ratio = self.factor ** 2
662599
if not interpolate:
663600
out_channels = out_channels * out_ratio
664601

665-
if interpolate:
666-
nn.conv = DCConv2d(
667-
in_channels=in_channels,
668-
out_channels=out_channels,
669-
kernel_size=kernel_size,
670-
)
671-
else:
672-
self.conv = DCConv2d(
673-
in_channels=in_channels,
674-
out_channels=out_channels,
675-
kernel_size=kernel_size,
676-
use_bias=True,
677-
norm=None,
678-
act_func=None,
679-
)
680-
self.conv = UpsamplePixelShuffle(
681-
in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, factor=2
682-
)
602+
self.conv = nn.Conv2d(
603+
in_channels,
604+
out_channels,
605+
kernel_size=kernel_size,
606+
stride=self.stride,
607+
padding=kernel_size // 2,
608+
)
683609

684610
self.shortcut = None
685611
if shortcut:
@@ -689,14 +615,17 @@ def __init__(
689615

690616
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
691617
if self.interpolate:
692-
x = torch.nn.functional.interpolate(x, scale_factor=self.factor, mode=self.interpolation_mode)
618+
x = F.interpolate(hidden_states, scale_factor=self.factor, mode=self.interpolation_mode)
693619
x = self.conv(x)
694620
else:
695621
x = self.conv(hidden_states)
622+
x = F.pixel_shuffle(x, self.factor)
623+
696624
if self.shortcut is not None:
697625
hidden_states = x + self.shortcut(hidden_states)
698626
else:
699627
hidden_states = x
628+
700629
return hidden_states
701630

702631

@@ -770,8 +699,6 @@ def __init__(
770699
def forward(self, x: torch.Tensor) -> torch.Tensor:
771700
x = self.conv_in(x)
772701
for stage in self.stages:
773-
if len(stage.op_list) == 0:
774-
continue
775702
x = stage(x)
776703
x = self.conv_out(x) + self.norm_out(x)
777704
return x
@@ -858,14 +785,13 @@ def __init__(
858785
self.conv_out = DCUpBlock2d(
859786
block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1],
860787
in_channels,
788+
interpolate=upsample_block_type == "InterpolateConv",
861789
shortcut=False,
862790
)
863791

864792
def forward(self, x: torch.Tensor) -> torch.Tensor:
865793
x = self.conv_in(x) + self.norm_in(x)
866794
for stage in reversed(self.stages):
867-
if len(stage.op_list) == 0:
868-
continue
869795
x = stage(x)
870796
x = self.norm_out(x)
871797
x = self.conv_act(x)

0 commit comments

Comments
 (0)