Skip to content

Commit 4ea355e

Browse files
authored
Added wrappers for conv and depthwiseconv (#127)
Added wrappers for conv and depthwiseconv
2 parents 77866d5 + 3ac9933 commit 4ea355e

File tree

2 files changed

+48
-8
lines changed

2 files changed

+48
-8
lines changed

src/conv.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter, ∇conv_filter!
1+
export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter, ∇conv_filter!, depthwiseconv,
2+
depthwiseconv!, ∇depthwiseconv_data, ∇depthwiseconv_data!, ∇depthwiseconv_filter,
3+
∇depthwiseconv_filter!
24

35
## Convolution API
46
#
@@ -161,3 +163,19 @@ if is_nnpack_available()
161163
return conv_nnpack(x, w, cdims; kwargs...)
162164
end
163165
end
166+
167+
function conv(x, w::AbstractArray{T, N}; stride = 1, pad = 0, dilation = 1, flipped = false) where {T, N}
168+
stride = expand(Val(N-2), stride)
169+
pad = expand(Val(N-2), pad)
170+
dilation = expand(Val(N-2), dilation)
171+
cdims = DenseConvDims(x, w; stride = stride, padding = pad, dilation = dilation, flipkernel = flipped)
172+
return conv(x, w, cdims)
173+
end
174+
175+
function depthwiseconv(x, w::AbstractArray{T, N}; stride = 1, pad = 0, dilation = 1, flipped = false) where {T, N}
176+
stride = expand(Val(N-2), stride)
177+
pad = expand(Val(N-2), pad)
178+
dilation = expand(Val(N-2), dilation)
179+
cdims = DepthwiseConvDims(x, w; stride = stride, padding = pad, dilation = dilation, flipkernel = flipped)
180+
return depthwiseconv(x, w, cdims)
181+
end

test/conv.jl

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ using NNlib: input_size, kernel_size, channels_in, channels_out, channel_multipl
1212
elseif T == DepthwiseConvDims
1313
w = randn(1,2,4,3)
1414
end
15-
15+
1616
# First, getters:
1717
cdims = T(x, w)
1818
@test input_size(cdims) == size(x)[1:2]
@@ -39,7 +39,7 @@ using NNlib: input_size, kernel_size, channels_in, channels_out, channel_multipl
3939
@test padding(cdims) == (3,3,3,3)
4040
@test flipkernel(cdims) == true
4141
@test output_size(cdims) == (6,4)
42-
42+
4343
# Next, tuple settings
4444
cdims = T(x, w; stride=(1, 2), dilation=(1, 2), padding=(0,1))
4545
@test stride(cdims) == (1,2)
@@ -215,7 +215,7 @@ conv_answer_dict = Dict(
215215
"dx_dil" => reshape([
216216
4864, 5152, 9696, 4508, 4760, 6304, 6592, 12396, 5768, 6020, 3648,
217217
3864, 7120, 3220, 3400, 4728, 4944, 9100, 4120, 4300, 0, 0, 0, 0, 0, 0,
218-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2432, 2576, 4544, 1932, 2040,
218+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2432, 2576, 4544, 1932, 2040,
219219
3152, 3296, 5804, 2472, 2580, 1216, 1288, 1968, 644, 680, 1576, 1648,
220220
2508, 824, 860.
221221
], (5,4,3)),
@@ -273,7 +273,7 @@ conv_answer_dict = Dict(
273273

274274
# A "drop channels and batch dimension" helper
275275
ddims(x) = dropdims(x, dims=(rank+1, rank+2))
276-
276+
277277
for conv in (NNlib.conv, NNlib.conv_im2col, NNlib.conv_direct)
278278
@testset "$(conv)" begin
279279
# First, your basic convolution with no parameters
@@ -392,7 +392,7 @@ conv_answer_dict = Dict(
392392
D_size in (1, 2, 4, (1,2), (3,2), (4,2,3))
393393

394394
# Skip tests that are impossible due to mismatched sizes
395-
try
395+
try
396396
DenseConvDims(x, w;
397397
stride=S_size, padding=P_size, dilation=D_size,
398398
)
@@ -473,7 +473,7 @@ end
473473

474474
# A "drop channels and batch dimension" helper
475475
ddims(x) = dropdims(x, dims=(rank+1, rank+2))
476-
476+
477477
for conv in (NNlib.depthwiseconv, NNlib.depthwiseconv_im2col, NNlib.depthwiseconv_direct)
478478
@testset "$(conv)" begin
479479
# First, your basic convolution with no parameters
@@ -592,7 +592,7 @@ end
592592
D_size in (1, 2, 4, (1,2), (3,2), (4,2,3))
593593

594594
# Skip tests that are impossible due to mismatched sizes
595-
try
595+
try
596596
DepthwiseConvDims(x, w;
597597
stride=S_size, padding=P_size, dilation=D_size,
598598
)
@@ -639,3 +639,25 @@ end
639639
println()
640640
end
641641
end
642+
643+
@testset "conv_wrapper" begin
644+
x = rand(10, 10, 3, 10)
645+
w = rand(2, 2, 3, 16)
646+
w1 = rand(3, 4, 3, 16)
647+
@test size(conv(x, w)) == (9, 9, 16, 10)
648+
@test size(conv(x, w; stride = (2, 2), pad = (2, 2))) == (7, 7, 16, 10)
649+
@test size(conv(x, w1; stride = (1, 2), pad = (2, 3))) == (12, 7, 16, 10)
650+
@test size(conv(x, w; stride = (1, 2), pad = (2, 3), dilation = (2, 2))) == (12, 7, 16, 10)
651+
@test size(conv(x, w; stride = (1, 2), pad = (2, 3), dilation = (2, 2), flipped = true)) == (12, 7, 16, 10)
652+
end
653+
654+
@testset "depthwiseconv_wrapper" begin
655+
x = rand(10, 10, 3, 10)
656+
w = rand(2, 2, 3, 3)
657+
w1 = rand(3, 4, 3, 3)
658+
@test size(depthwiseconv(x, w)) == (9, 9, 9, 10)
659+
@test size(depthwiseconv(x, w; stride = (2, 2), pad = (2, 2))) == (7, 7, 9, 10)
660+
@test size(depthwiseconv(x, w1; stride = (1, 2), pad = (2, 3))) == (12, 7, 9, 10)
661+
@test size(depthwiseconv(x, w1; stride = (1, 2), pad = (2, 3), dilation = (2, 2))) == (10, 5, 9, 10)
662+
@test size(depthwiseconv(x, w1; stride = (1, 2), pad = (2, 3), dilation = (2, 2), flipped = true)) == (10, 5, 9, 10)
663+
end

0 commit comments

Comments
 (0)