Skip to content

Commit 6e89a3a

Browse files
committed
Added support for grouped convolutions (and therefore, depth-wise separable convolutions).
1 parent 71ef0c4 commit 6e89a3a

File tree

2 files changed

+44
-22
lines changed

2 files changed

+44
-22
lines changed

benchmark.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,22 +34,23 @@ def benchmark_conv(
3434
bias: Tensor,
3535
padding: Union[int, Iterable[int]] = 0,
3636
stride: Union[int, Iterable[int]] = 1,
37+
groups: int = 1,
3738
):
3839
print(f"Signal size: {signal.shape}")
3940
print(f"Kernel size: {kernel.shape}")
4041

4142
torch_conv = {1: f.conv1d, 2: f.conv2d, 3: f.conv3d}[signal.ndim - 2]
4243
direct_time = benchmark(
43-
torch_conv, signal, kernel, bias=bias, padding=padding, stride=stride
44+
torch_conv, signal, kernel, bias=bias, padding=padding, stride=stride, groups=groups
4445
)
4546
fourier_time = benchmark(
46-
fft_conv, signal, kernel, bias=bias, padding=padding, stride=stride
47+
fft_conv, signal, kernel, bias=bias, padding=padding, stride=stride, groups=groups
4748
)
4849
print(f"Direct time: {direct_time}")
4950
print(f"Fourier time: {fourier_time}")
5051

51-
y0 = torch_conv(signal, kernel, bias=bias, padding=padding, stride=stride)
52-
y1 = fft_conv(signal, kernel, bias=bias, padding=padding, stride=stride)
52+
y0 = torch_conv(signal, kernel, bias=bias, padding=padding, stride=stride, groups=groups)
53+
y1 = fft_conv(signal, kernel, bias=bias, padding=padding, stride=stride, groups=groups)
5354
abs_error = torch.abs(y0 - y1)
5455
print(f"Output size: {y0.size()}")
5556
print(f"Abs Error Mean: {abs_error.mean():.3E}")
@@ -58,27 +59,30 @@ def benchmark_conv(
5859

5960
print("\n--- 1D Convolution ---")
6061
benchmark_conv(
61-
signal=torch.randn(3, 3, 4091),
62-
kernel=torch.randn(2, 3, 1025),
63-
bias=torch.randn(2),
62+
signal=torch.randn(4, 4, 4091),
63+
kernel=torch.randn(6, 2, 1025),
64+
bias=torch.randn(6),
6465
padding=512,
6566
stride=3,
67+
groups=2,
6668
)
6769

6870
print("\n--- 2D Convolution ---")
6971
benchmark_conv(
70-
signal=torch.randn(3, 3, 256, 235),
71-
kernel=torch.randn(2, 3, 19, 21),
72-
bias=torch.randn(2),
72+
signal=torch.randn(4, 4, 256, 235),
73+
kernel=torch.randn(6, 2, 22, 21),
74+
bias=torch.randn(6),
7375
padding=(9, 10),
7476
stride=(2, 3),
77+
groups=2,
7578
)
7679

7780
print("\n--- 3D Convolution ---")
7881
benchmark_conv(
79-
signal=torch.randn(3, 3, 64, 72, 61),
80-
kernel=torch.randn(2, 3, 5, 7, 9),
81-
bias=torch.randn(2),
82+
signal=torch.randn(4, 4, 96, 72, 61),
83+
kernel=torch.randn(6, 2, 12, 7, 9),
84+
bias=torch.randn(6),
8285
padding=(2, 3, 4),
83-
stride=(1, 2, 3)
86+
stride=(1, 2, 3),
87+
groups=2,
8488
)

fft_conv.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,15 @@
77
import torch.nn.functional as f
88

99

10-
def complex_matmul(a: Tensor, b: Tensor) -> Tensor:
10+
def complex_matmul(a: Tensor, b: Tensor, groups: int = 1) -> Tensor:
1111
"""Multiplies two complex-valued tensors."""
12-
# Scalar matrix multiplication of two tensors, over only the first two dimensions.
13-
# Dimensions 3 and higher will have the same shape after multiplication.
14-
scalar_matmul = partial(torch.einsum, "ab..., cb... -> ac...")
12+
# Scalar matrix multiplication of two tensors, contracting only the channel
13+
# dimension. All trailing dimensions keep the same shape after multiplication.
14+
# We also allow for "grouped" multiplications, where separate sections of channels
15+
# are multiplied independently of one another (required for grouped convolutions).
16+
scalar_matmul = partial(torch.einsum, "agc..., gbc... -> agb...")
17+
a = a.view(a.size(0), groups, -1, *a.shape[2:])
18+
b = b.view(groups, -1, *b.shape[1:])
1519

1620
# Compute the real and imaginary parts independently, then manually insert them
1721
# into the output Tensor. This is fairly hacky but necessary for PyTorch 1.7.0,
@@ -22,7 +26,7 @@ def complex_matmul(a: Tensor, b: Tensor) -> Tensor:
2226
c = torch.zeros(real.shape, dtype=torch.complex64, device=a.device)
2327
c.real, c.imag = real, imag
2428

25-
return c
29+
return c.view(c.size(0), -1, *c.shape[3:])
2630

2731

2832
def to_ntuple(val: Union[int, Iterable[int]], n: int) -> Tuple[int, ...]:
@@ -52,6 +56,7 @@ def fft_conv(
5256
bias: Tensor = None,
5357
padding: Union[int, Iterable[int]] = 0,
5458
stride: Union[int, Iterable[int]] = 1,
59+
groups: int = 1,
5560
) -> Tensor:
5661
"""Performs N-d convolution of Tensors using a fast fourier transform, which
5762
is very fast for large kernel sizes. Also, optionally adds a bias Tensor after
@@ -78,7 +83,7 @@ def fft_conv(
7883

7984
# Because PyTorch computes a *one-sided* FFT, we need the final dimension to
8085
# have *even* length. Just pad with one more zero if the final dimension is odd.
81-
if signal.size(-1) % 2:
86+
if signal.size(-1) % 2 != 0:
8287
signal_ = f.pad(signal, [0, 1])
8388
else:
8489
signal_ = signal
@@ -91,11 +96,12 @@ def fft_conv(
9196
padded_kernel = f.pad(kernel, kernel_padding)
9297

9398
# Perform fourier convolution -- FFT, matrix multiply, then IFFT
99+
# signal_ = signal_.reshape(signal_.size(0), groups, -1, *signal_.shape[2:])
94100
signal_fr = rfftn(signal_, dim=tuple(range(2, signal.ndim)))
95101
kernel_fr = rfftn(padded_kernel, dim=tuple(range(2, signal.ndim)))
96102

97103
kernel_fr.imag *= -1
98-
output_fr = complex_matmul(signal_fr, kernel_fr)
104+
output_fr = complex_matmul(signal_fr, kernel_fr, groups=groups)
99105
output = irfftn(output_fr, dim=tuple(range(2, signal.ndim)))
100106

101107
# Remove extra padded values
@@ -123,6 +129,7 @@ def __init__(
123129
kernel_size: Union[int, Iterable[int]],
124130
padding: Union[int, Iterable[int]] = 0,
125131
stride: Union[int, Iterable[int]] = 1,
132+
groups: int = 1,
126133
bias: bool = True,
127134
ndim: int = 1,
128135
):
@@ -142,10 +149,20 @@ def __init__(
142149
self.kernel_size = kernel_size
143150
self.padding = padding
144151
self.stride = stride
152+
self.groups = groups
145153
self.use_bias = bias
146154

155+
if in_channels % 2 != 0:
156+
raise ValueError(
157+
f"'in_channels' ({in_channels}) must be divisible by 'groups' ({groups})."
158+
)
159+
if out_channels % 2 != 0:
160+
raise ValueError(
161+
f"'out_channels' ({out_channels}) must be divisible by 'groups' ({groups})."
162+
)
163+
147164
kernel_size = to_ntuple(kernel_size, ndim)
148-
self.weight = nn.Parameter(torch.randn(out_channels, in_channels, *kernel_size))
165+
self.weight = nn.Parameter(torch.randn(out_channels, in_channels // groups, *kernel_size))
149166
self.bias = nn.Parameter(torch.randn(out_channels,)) if bias else None
150167

151168
def forward(self, signal):
@@ -155,6 +172,7 @@ def forward(self, signal):
155172
bias=self.bias,
156173
padding=self.padding,
157174
stride=self.stride,
175+
groups=self.groups,
158176
)
159177

160178

0 commit comments

Comments
 (0)