Skip to content

Commit 4d3c026

Browse files
committed
change file name to autoencoder_dc
1 parent 2f6bbad commit 4d3c026

File tree

1 file changed

+21
-27
lines changed

1 file changed

+21
-27
lines changed

src/diffusers/models/autoencoders/dc_ae.py renamed to src/diffusers/models/autoencoders/autoencoder_dc.py

Lines changed: 21 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -153,15 +153,15 @@ def __init__(
153153
)
154154

155155
def forward(self, x: torch.Tensor) -> torch.Tensor:
156-
x = self.inverted_conv(x)
157-
x = self.depth_conv(x)
156+
y = self.inverted_conv(x)
157+
y = self.depth_conv(y)
158158

159-
x, gate = torch.chunk(x, 2, dim=1)
159+
y, gate = torch.chunk(y, 2, dim=1)
160160
gate = self.glu_act(gate)
161-
x = x * gate
161+
y = y * gate
162162

163-
x = self.point_conv(x)
164-
return x
163+
y = self.point_conv(y)
164+
return x + y
165165

166166

167167
class ResBlock(nn.Module):
@@ -349,7 +349,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
349349
out = self.relu_quadratic_att(qkv)
350350
out = self.proj(out)
351351

352-
return out
352+
return x + out
353353

354354

355355
class EfficientViTBlock(nn.Module):
@@ -367,30 +367,24 @@ def __init__(
367367
):
368368
super().__init__()
369369
if context_module == "LiteMLA":
370-
self.context_module = ResidualBlock(
371-
LiteMLA(
372-
in_channels=in_channels,
373-
out_channels=in_channels,
374-
heads_ratio=heads_ratio,
375-
dim=dim,
376-
norm=(None, norm),
377-
scales=scales,
378-
),
379-
nn.Identity(),
370+
self.context_module = LiteMLA(
371+
in_channels=in_channels,
372+
out_channels=in_channels,
373+
heads_ratio=heads_ratio,
374+
dim=dim,
375+
norm=(None, norm),
376+
scales=scales,
380377
)
381378
else:
382379
raise ValueError(f"context_module {context_module} is not supported")
383380
if local_module == "GLUMBConv":
384-
self.local_module = ResidualBlock(
385-
GLUMBConv(
386-
in_channels=in_channels,
387-
out_channels=in_channels,
388-
expand_ratio=expand_ratio,
389-
use_bias=(True, True, False),
390-
norm=(None, None, norm),
391-
act_func=(act_func, act_func, None),
392-
),
393-
nn.Identity(),
381+
self.local_module = GLUMBConv(
382+
in_channels=in_channels,
383+
out_channels=in_channels,
384+
expand_ratio=expand_ratio,
385+
use_bias=(True, True, False),
386+
norm=(None, None, norm),
387+
act_func=(act_func, act_func, None),
394388
)
395389
else:
396390
raise NotImplementedError(f"local_module {local_module} is not supported")

0 commit comments

Comments
 (0)