Skip to content

Commit 7087954

Browse files
authored
Merge pull request #441 from FluxML/bc/dw-conv-debug
Remove threading from all `∇*conv_filter` and re-enable old tests
2 parents f5fd67e + ec54732 commit 7087954

File tree

4 files changed

+46
-36
lines changed

4 files changed

+46
-36
lines changed

src/dim_helpers/ConvDims.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,25 @@ function im2col_dims(c::ConvDims)
7777
)
7878
end
7979

80+
"""
81+
∇filter_im2col_dims(c::ConvDims)
82+
83+
Like [`im2col_dims`](@ref), but saves some memory because multiple (Julia) threads are
84+
not required for the filter gradient calculation.
85+
86+
Note: in the future, this may return `Dims{2}` instead of `Dims{3}`.
87+
"""
88+
function ∇filter_im2col_dims(c::ConvDims)
89+
return (
90+
# Output size
91+
prod(output_size(c)),
92+
# Size of a single dot product within the convolution
93+
prod(kernel_size(c))*channels_in(c),
94+
# No threading, this is just here for backwards compat
95+
1
96+
)
97+
end
98+
8099
# Protect your skin, kids. Also do common validation of stride, padding, etc...
81100
function check_spdf(x_size::NTuple{N}, w_size::NTuple{N}, stride, padding, dilation) where {N}
82101
# Number of spatial dimensions in `x` and `w`.

src/impl/conv_im2col.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,15 +60,16 @@ function conv_im2col!(
6060
end
6161

6262
"""
63-
∇conv_filter_im2col!(dw, x, dy, cdims, col=similar(dw); alpha=1, beta=0)
63+
∇conv_filter_im2col!(dw, x, dy, cdims, col=similar(dw, ∇filter_im2col_dims(cdims));
64+
alpha=1, beta=0)
6465
6566
Conv backward pass onto the weights using im2col and GEMM; stores the result in `dw`.
66-
See the documentation for `conv_im2col!()` for explanation of optional parameters.
67+
See [`conv_im2col!`](@ref) for explanation of optional parameters.
6768
"""
6869
function ∇conv_filter_im2col!(
6970
dw::AbstractArray{T,5}, x::AbstractArray{T,5},
7071
dy::AbstractArray{T,5}, cdims::DenseConvDims;
71-
col::AbstractArray{T,3} = similar(dw, im2col_dims(cdims)),
72+
col::AbstractArray{T,3} = similar(dw, ∇filter_im2col_dims(cdims)),
7273
alpha::T=T(1), beta::T=T(0)) where {T}
7374
check_dims(size(x), size(dw), size(dy), cdims)
7475

@@ -115,7 +116,7 @@ end
115116
∇conv_data_im2col!(dx, w, dy, cdims, col=similar(dx); alpha=1, beta=0)
116117
117118
Conv2d backward pass onto the input using im2col and GEMM; stores the result in `dx`.
118-
See the documentation for `conv_im2col!()` for explanation of other parameters.
119+
See [`conv_im2col!`](@ref) for explanation of optional parameters.
119120
"""
120121
function ∇conv_data_im2col!(
121122
dx::AbstractArray{T,5}, dy::AbstractArray{T,5},

src/impl/depthwiseconv_im2col.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55
depthwiseconv_im2col!(y, x, w, cdims, col=similar(x); alpha=1, beta=0)
66
77
Perform a depthwise convolution using im2col and GEMM, store the result in `y`.
8-
9-
See `conv_im2col!()` for an explanation of optional parameters.
8+
See [`conv_im2col!`](@ref) for explanation of optional parameters.
109
"""
1110
depthwiseconv_im2col!
1211

@@ -48,27 +47,32 @@ function depthwiseconv_im2col!(
4847
end
4948

5049
"""
51-
∇depthwiseconv_filter_im2col!(dw, w, dy, cdims, col=similar(dw); alpha=1, beta)
50+
∇depthwiseconv_filter_im2col!(dw, x, dy, cdims, col=similar(dw, ∇filter_im2col_dims(cdims));
51+
alpha=1, beta=0)
5252
53-
Depthwise conv2d backward pass onto the weights using im2col and GEMM.
54-
See the documentation for `conv_im2col!()` for explanation of optional parameters.
53+
Depthwise conv backward pass onto the weights using im2col and GEMM.
54+
See [`conv_im2col!`](@ref) for explanation of optional parameters.
5555
"""
5656
∇depthwiseconv_filter_im2col!
5757

5858
function ∇depthwiseconv_filter_im2col!(
5959
dw::AbstractArray{T,5}, x::AbstractArray{T,5},
6060
dy::AbstractArray{T,5}, cdims::DepthwiseConvDims;
61-
col::AbstractArray{T,3} = similar(dw, im2col_dims(cdims)),
61+
col::AbstractArray{T,3} = similar(dw, ∇filter_im2col_dims(cdims)),
6262
alpha::T=T(1), beta::T=T(0)) where T
6363
check_dims(size(x), size(dw), size(dy), cdims)
6464

6565
M = prod(kernel_size(cdims))
6666
N = channel_multiplier(cdims)
6767
K = prod(output_size(cdims))
6868

69-
@threads for batch_idx in 1:size(x)[end]
69+
for batch_idx in 1:size(x, 5)
70+
# Because we accumulate over batches in this loop, we must set `beta` equal
71+
# to `1.0` after the first sample.
72+
beta′ = batch_idx == 1 ? beta : T(1)
73+
7074
# col_slice is the im2col workspace (a single slice; this loop is not threaded)
71-
col_slice = view(col, :, :, threadid())
75+
col_slice = view(col, :, :, 1)
7276
im2col!(col_slice, view(x, :, :, :, :, batch_idx), cdims)
7377

7478
# We do a separate convolution for each channel in x, as we must
@@ -78,22 +82,18 @@ function ∇depthwiseconv_filter_im2col!(
7882
col_ptr = pointer(col_slice, (c_in - 1)*M*K + 1)
7983
dy_ptr = pointer(dy, (batch_idx - 1)*N*K*channels_in(cdims) + (c_in - 1)*K*N + 1)
8084
dw_ptr = pointer(dw, (c_in - 1)*M*N + 1)
81-
gemm!(Val(true), Val(false), M, N, K, alpha, col_ptr, dy_ptr, beta, dw_ptr)
85+
gemm!(Val(true), Val(false), M, N, K, alpha, col_ptr, dy_ptr, beta′, dw_ptr)
8286
end
8387
end
84-
85-
# Because we accumulate over batches in this loop, we must set `beta` equal
86-
# to `1.0` from this point on.
87-
beta = T(1)
8888
end
8989
return dw
9090
end
9191

9292
"""
93-
depthwiseconv2d_Δx_im2col!(dx, w, dy, cdims, col=similar(dx); alpha=1, beta=0)
93+
∇depthwiseconv_data_im2col!(dx, w, dy, cdims, col=similar(dx); alpha=1, beta=0)
9494
9595
Depthwise conv2d backward pass onto the input using im2col and GEMM.
96-
See the documentation for `conv_im2col!()` for explanation of optional parameters.
96+
See [`conv_im2col!`](@ref) for explanation of optional parameters.
9797
"""
9898
∇depthwiseconv_data_im2col!
9999

test/conv.jl

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -737,7 +737,7 @@ end
737737
end
738738

739739
# https://github.com/FluxML/NNlib.jl/pull/171
740-
@testset "conv_direct! - Check Sizes" begin
740+
@testset "conv_direct! - Check Sizes" begin
741741
x_size = (6, 7, 8, 5, 3)
742742
y_size = (5, 6, 7, 4, 3)
743743
w_size = (2, 2, 2, 5, 4)
@@ -759,25 +759,15 @@ end
759759

760760
y = conv(x, w, cdims)
761761
gradtest((y, w) -> ∇conv_data(y, w, cdims), y, w)
762-
# if spatial_rank == 3
763-
# @test_broken gradtest((y, w) -> sum(∇conv_data(y, w, cdims)), y, w)
764-
# else
765-
gradtest((y, w) -> sum(∇conv_data(y, w, cdims)), y, w)
766-
# end
767-
gradtest((x, y) -> ∇conv_filter(x, y, cdims), x, y)
768-
if spatial_rank < 3
769-
gradtest((x, y) -> sum(∇conv_filter(x, y, cdims)), x, y)
770-
end
762+
gradtest((y, w) -> sum(∇conv_data(y, w, cdims)), y, w)
763+
gradtest((x, y) -> ∇conv_filter(x, y, cdims), x, y)
764+
gradtest((x, y) -> sum(∇conv_filter(x, y, cdims)), x, y)
771765

772766
dcdims = DepthwiseConvDims(x, w)
773767
gradtest((x, w) -> depthwiseconv(x, w, dcdims), x, w)
774768

775769
# FIXME fails
776-
# y = depthwiseconv(x, w, dcdims)
777-
# gradtest((y, w) -> ∇depthwiseconv_data(y, w, dcdims), y, w)
778-
# if spatial_rank == 3
779-
# @test_broken gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w)
780-
# else
781-
@test_skip gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w)
782-
# end
770+
y = depthwiseconv(x, w, dcdims)
771+
gradtest((y, w) -> ∇depthwiseconv_data(y, w, dcdims), y, w)
772+
gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w)
783773
end

0 commit comments

Comments
 (0)