Skip to content

Commit ea604a4

Browse files
committed
update ResBlock
1 parent 59de0a3 commit ea604a4

File tree

1 file changed

+24
-22
lines changed
  • src/diffusers/models/autoencoders

1 file changed

+24
-22
lines changed

src/diffusers/models/autoencoders/dc_ae.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515

1616
from typing import Any, Optional, Callable, Union
17+
from collections import OrderedDict
1718

1819
import torch
1920
import torch.nn as nn
@@ -169,28 +170,30 @@ def __init__(
169170
super().__init__()
170171
mid_channels = round(in_channels * expand_ratio) if mid_channels is None else mid_channels
171172

172-
self.conv1 = ConvLayer(
173-
in_channels,
174-
mid_channels,
175-
kernel_size,
176-
stride,
177-
use_bias=use_bias[0],
178-
norm=norm[0],
179-
act_func=act_func[0],
180-
)
181-
self.conv2 = ConvLayer(
182-
mid_channels,
183-
out_channels,
184-
kernel_size,
185-
1,
186-
use_bias=use_bias[1],
187-
norm=norm[1],
188-
act_func=act_func[1],
189-
)
173+
self.main = nn.Sequential(OrderedDict([
174+
("conv1", ConvLayer(
175+
in_channels,
176+
mid_channels,
177+
kernel_size,
178+
stride,
179+
use_bias=use_bias[0],
180+
norm=norm[0],
181+
act_func=act_func[0],
182+
)),
183+
("conv2", ConvLayer(
184+
mid_channels,
185+
out_channels,
186+
kernel_size,
187+
1,
188+
use_bias=use_bias[1],
189+
norm=norm[1],
190+
act_func=act_func[1],
191+
)),
192+
]))
193+
self.shortcut = nn.Identity()
190194

191195
def forward(self, x: torch.Tensor) -> torch.Tensor:
192-
x = self.conv1(x)
193-
x = self.conv2(x)
196+
x = self.main(x) + self.shortcut(x)
194197
return x
195198

196199

@@ -448,7 +451,7 @@ def build_block(
448451
) -> nn.Module:
449452
if block_type == "ResBlock":
450453
assert in_channels == out_channels
451-
main_block = ResBlock(
454+
block = ResBlock(
452455
in_channels=in_channels,
453456
out_channels=out_channels,
454457
kernel_size=3,
@@ -457,7 +460,6 @@ def build_block(
457460
norm=(None, norm),
458461
act_func=(act, None),
459462
)
460-
block = ResidualBlock(main_block, nn.Identity())
461463
elif block_type == "EViTGLU":
462464
assert in_channels == out_channels
463465
block = EfficientViTBlock(in_channels, norm=norm, act_func=act, local_module="GLUMBConv", scales=())

0 commit comments

Comments
 (0)