Skip to content

Commit 8da76bd

Browse files
authored
Fix grad conv im2col (#539)
* Fix grad conv im2col * Also fix depthwise * Enable previously broken tests * Revert "Enable previously broken tests" This reverts commit d648fdd. * Add explicit im2col test * Fix and test third case * More tests now pass
1 parent 37d9a02 commit 8da76bd

File tree

3 files changed

+28
-13
lines changed

3 files changed

+28
-13
lines changed

src/impl/conv_im2col.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ function ∇conv_data_im2col!(
162162
col_ptr = pointer(col_slice)
163163
gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
164164
end
165-
col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims)
165+
col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims, beta)
166166
end
167167
end
168168
end
@@ -276,7 +276,7 @@ end
276276

277277

278278
"""
279-
col2im!(x, col, cdims)
279+
col2im!(x, col, cdims, beta=0)
280280
281281
Does the inverse of `im2col!()`, converting `col` back into a 3d image, used for backward
282282
passes, transposed convolutions, etc...
@@ -287,7 +287,7 @@ desperate enough yet.
287287
"""
288288
col2im!
289289

290-
function col2im!(x::AbstractArray{T,4}, col::AbstractArray{T,2}, cdims::ConvDims) where T
290+
function col2im!(x::AbstractArray{T,4}, col::AbstractArray{T,2}, cdims::ConvDims, beta::T=T(0)) where T
291291
if spatial_dims(cdims) != 3
292292
throw(DimensionMismatch("col2im!() only accepts 3d convoluitional inputs"))
293293
end
@@ -303,7 +303,13 @@ function col2im!(x::AbstractArray{T,4}, col::AbstractArray{T,2}, cdims::ConvDims
303303

304304
# TODO: Rewrite this method so we don't have this fill!() at the beginning!
305305
# Calculate each output pixel once rather than accumulating into it?
306-
fill!(x, T(0))
306+
if beta == T(0)
307+
fill!(x, T(0))
308+
elseif beta == T(1)
309+
# nothing
310+
else
311+
x .*= beta
312+
end
307313

308314
# Reshape col for easy access.
309315
col_reshaped = reshape(col, (

src/impl/depthwiseconv_im2col.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ function ∇depthwiseconv_data_im2col!(
131131
gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
132132
end
133133
end
134-
col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims)
134+
col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims, beta)
135135
end
136136
end
137137
end

test/conv.jl

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,16 @@ ddims(x) = dropdims(x, dims=(ndims(x)-1, ndims(x)))
399399
end
400400
end
401401

402+
# Test im2col
403+
404+
for beta in (-2.0, -1.0, 0.0, 0.5, 1.0, 2.0)
405+
cache_dx, cache_dy, cache_w = ([0.17;;; 0.19;;; 0.23], [0.11;;; 0.13;;; 0.15], [1.0;;;])
406+
dx_old = copy(cache_dx)
407+
cdims = DenseConvDims(cache_dx, cache_w)
408+
NNlib.∇conv_data_im2col!(cache_dx, cache_dy, cache_w, cdims; alpha=1.0, beta)
409+
@test isapprox(cache_dx, dx_old * beta + cache_dy, rtol = 1.0e-7)
410+
end
411+
402412
# Test all in-place implementations/interfaces
403413
for (∇conv_filter!, ∇conv_data!) in (
404414
(NNlib.∇conv_filter!, NNlib.∇conv_data!),
@@ -407,47 +417,46 @@ ddims(x) = dropdims(x, dims=(ndims(x)-1, ndims(x)))
407417
)
408418
#α, β = 2*rand(rng) - 1, 2*rand(rng) - 1
409419
α, β = 2e0, -1e0
410-
flag = ∇conv_data! in (NNlib.∇conv_data!, NNlib.∇conv_data_im2col!)
411420

412421
@testset "$(∇conv_filter!)/$(∇conv_data!)" begin
413422
# First, your basic convolution with no parameters
414423
cdims = DenseConvDims(x, w)
415424
dy = NNlib.conv(x, w, cdims)
416425
@test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw + β*w, rtol = 1.0e-7)
417-
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx + β*x, rtol = 1.0e-7) broken=flag
426+
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx + β*x, rtol = 1.0e-7)
418427

419428
# Next, test convolution on views and alternate datatypes:
420429
@test isapprox(ddims(∇conv_filter!(copy(w), x, view(dy, repeat([:], ndims(dy))...), cdims; alpha=α, beta=β)), α*dw + β*w, rtol = 1.0e-7)
421-
@test isapprox(ddims(∇conv_data!(copy(x), view(dy, repeat([:], ndims(dy))...), w, cdims; alpha=α, beta=β)), α*dx + β*x, rtol = 1.0e-7) broken=flag
430+
@test isapprox(ddims(∇conv_data!(copy(x), view(dy, repeat([:], ndims(dy))...), w, cdims; alpha=α, beta=β)), α*dx + β*x, rtol = 1.0e-7)
422431

423432
@test isapprox(ddims(∇conv_filter!(Float32.(copy(w)), Float32.(x), Float32.(dy), cdims; alpha=Float32(α), beta=Float32(β))), α*dw + β*w, rtol = 1.0e-7)
424-
@test isapprox(ddims(∇conv_data!(Float32.(copy(x)), Float32.(dy), Float32.(w), cdims; alpha=Float32(α), beta=Float32(β))), α*dx + β*x, rtol = 1.0e-7) broken=flag
433+
@test isapprox(ddims(∇conv_data!(Float32.(copy(x)), Float32.(dy), Float32.(w), cdims; alpha=Float32(α), beta=Float32(β))), α*dx + β*x, rtol = 1.0e-7)
425434

426435
# Next, introduce stride:
427436
cdims = DenseConvDims(x, w; stride=2)
428437
dy = NNlib.conv(x, w, cdims)
429438
flag_ = ∇conv_filter! == NNlib.∇conv_filter_direct! && rank in (1,3)
430439
@test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_stride + β*w, rtol = 1.0e-7) broken=flag_
431-
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_stride + β*x, rtol = 1.0e-7) broken=flag
440+
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_stride + β*x, rtol = 1.0e-7)
432441

433442
# Next, introduce dilation:
434443
cdims = DenseConvDims(x, w; dilation=2)
435444
dy = NNlib.conv(x, w, cdims)
436445
flag_ = ∇conv_data! == NNlib.∇conv_data_direct! && rank == 3
437446
@test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_dil + β*w, rtol = 1.0e-7)
438-
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_dil + β*x, rtol = 1.0e-7) broken=flag || flag_
447+
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_dil + β*x, rtol = 1.0e-7) broken=flag_
439448

440449
# Next, introduce padding:
441450
cdims = DenseConvDims(x, w; padding=1)
442451
dy = NNlib.conv(x, w, cdims)
443452
@test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_pad + β*w, rtol = 1.0e-7)
444-
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_pad + β*x, rtol = 1.0e-7) broken=flag
453+
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_pad + β*x, rtol = 1.0e-7)
445454

446455
# Next, test crosscor/conv with a flipped kernel
447456
cdims = DenseConvDims(x, w; flipkernel=true)
448457
dy = NNlib.conv(x, w, cdims)
449458
@test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_flip + β*w, rtol = 1.0e-7)
450-
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_flip + β*x, rtol = 1.0e-7) broken=flag
459+
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_flip + β*x, rtol = 1.0e-7)
451460
end
452461
end
453462
end

0 commit comments

Comments
 (0)