Skip to content

Commit 1bcc69e

Browse files
committed
Use in_channels for depthwise groups, allows using out_channels=N * in_channels (does not impact existing models). Fix #354.
1 parent 9811e22 commit 1bcc69e

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

timm/models/layers/create_conv2d.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ def create_conv2d(in_channels, out_channels, kernel_size, **kwargs):
2222
m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs)
2323
else:
2424
depthwise = kwargs.pop('depthwise', False)
25-
groups = out_channels if depthwise else kwargs.pop('groups', 1)
25+
# for DW out_channels must be multiple of in_channels as must have out_channels % groups == 0
26+
groups = in_channels if depthwise else kwargs.pop('groups', 1)
2627
if 'num_experts' in kwargs and kwargs['num_experts'] > 0:
2728
m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs)
2829
else:

timm/models/layers/mixed_conv2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def __init__(self, in_channels, out_channels, kernel_size=3,
3434
self.in_channels = sum(in_splits)
3535
self.out_channels = sum(out_splits)
3636
for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
37-
conv_groups = out_ch if depthwise else 1
37+
conv_groups = in_ch if depthwise else 1
3838
# use add_module to keep key space clean
3939
self.add_module(
4040
str(idx),

0 commit comments

Comments
 (0)