Skip to content

Commit 503fd23

Browse files
committed
MaxPool backprop requires knowledge of y. :(
1 parent 2503cd4 commit 503fd23

File tree

3 files changed

+22
-17
lines changed

3 files changed

+22
-17
lines changed

src/impl/pooling_direct.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,8 @@ for name in (:max, :mean)
123123
# it's unfortunately different enough that I think we need a separate function. :(
124124
@eval @timeit_debug to function $((Symbol("$(name)pool_direct!")))(
125125
dx::AbstractArray{T,5}, dy::AbstractArray{T,5},
126-
x::AbstractArray{T,5}, pdims::PoolDims;
127-
alpha::T = T(1), beta::T = T(0)) where {T}
126+
y::AbstractArray{T,5}, x::AbstractArray{T,5},
127+
pdims::PoolDims; alpha::T = T(1), beta::T = T(0)) where {T}
128128
check_dims(size(x), size(dy), pdims)
129129

130130
width, height, depth = input_size(pdims)
@@ -156,7 +156,8 @@ for name in (:max, :mean)
156156
h in h_region,
157157
w in w_region
158158

159-
# Grab the incoming gradient at this index for future use
159+
# Grab the output at this index for future use
160+
y_idx = y[w, h, d, c, batch_idx]
160161
dy_idx = dy[w, h, d, c, batch_idx]
161162
maxpool_already_chose = false
162163

@@ -174,7 +175,7 @@ for name in (:max, :mean)
174175
if $(name == :max)
175176
# If it's equal; this is the one we chose. We only choose one per
176177
# kernel window, all other elements of dx must be zero.
177-
if dy_idx == x[x_idxs...] && !maxpool_already_chose
178+
if y_idx == x[x_idxs...] && !maxpool_already_chose
178179
dx[x_idxs...] = dy_idx*alpha + beta*dx[x_idxs...]
179180
maxpool_already_chose = true
180181
# Maxpooling does not support `beta` right now. :(
@@ -199,6 +200,7 @@ for name in (:max, :mean)
199200
w in w_region
200201

201202
# Grab the incoming gradient at this index for future use
203+
y_idx = dy[w, h, d, c, batch_idx]
202204
dy_idx = dy[w, h, d, c, batch_idx]
203205
maxpool_already_chose = false
204206

@@ -225,7 +227,7 @@ for name in (:max, :mean)
225227
# Same as above
226228
x_idxs = (input_kw, input_kh, input_kd, c, batch_idx)
227229
if $(name == :max)
228-
if dy_idx == x[x_idxs...] && !maxpool_already_chose
230+
if y_idx == x[x_idxs...] && !maxpool_already_chose
229231
dx[x_idxs...] = dy_idx*alpha + beta*dx[x_idxs...]
230232
maxpool_already_chose = true
231233
#else

src/pooling.jl

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,10 @@ for (front_name, backend) in (
4747
)
4848
@eval begin
4949
function $(Symbol("$(front_name)!"))(
50-
dx::AbstractArray{T,5}, dy::AbstractArray{T,5},
51-
x::AbstractArray{T,5}, pdims::PoolDims; kwargs...) where {T}
52-
$(Symbol("$(front_name)_$(backend)!"))(dx, dy, x, pdims; kwargs...)
50+
dx::AbstractArray{T,5}, dy::AbstractArray{T,5},
51+
y::AbstractArray{T,5}, x::AbstractArray{T,5},
52+
pdims::PoolDims; kwargs...) where {T}
53+
$(Symbol("$(front_name)_$(backend)!"))(dx, dy, y, x, pdims; kwargs...)
5354
end
5455
end
5556
end
@@ -79,12 +80,13 @@ for front_name in (:maxpool, :meanpool)
7980

8081
# backprops too
8182
function $(Symbol("$(front_name)$(backend)!"))(
82-
dx::AbstractArray{T,$N}, dy::AbstractArray{T,$N},
83-
x::AbstractArray{T,$N}, pdims::PoolDims;
84-
kwargs...) where {T}
83+
dx::AbstractArray{T,$N}, dy::AbstractArray{T,$N},
84+
y::AbstractArray{T,$N}, x::AbstractArray{T,$N},
85+
pdims::PoolDims; kwargs...) where {T}
8586
$(Symbol("$(front_name)$(backend)!"))(
8687
insert_singleton_spatial_dimension(dx, $(5 - N)),
8788
insert_singleton_spatial_dimension(dy, $(5 - N)),
89+
insert_singleton_spatial_dimension(y, $(5 - N)),
8890
insert_singleton_spatial_dimension(x, $(5 - N)),
8991
insert_singleton_spatial_dimension(pdims, $(5 - N));
9092
kwargs...
@@ -114,10 +116,11 @@ for backend in (Symbol(), :_direct, :_im2col)
114116

115117
# Backprops too
116118
@timeit_debug to function $(Symbol("$(name)$(backend)"))(
117-
dy::AbstractArray{T,N}, x::AbstractArray{T},
118-
pdims::PoolDims; kwargs...) where {T, N}
119+
dy::AbstractArray{T,N}, y::AbstractArray{T,N},
120+
x::AbstractArray{T,N}, pdims::PoolDims;
121+
kwargs...) where {T, N}
119122
dx = zeros(T, input_size(pdims)..., channels_in(pdims), size(dy, N))
120-
return $(Symbol("$(name)$(backend)!"))(dx, dy, x, pdims; kwargs...)
123+
return $(Symbol("$(name)$(backend)!"))(dx, dy, y, x, pdims; kwargs...)
121124
end
122125
end
123126
end

test/pooling.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -280,19 +280,19 @@ for rank in (1, 2, 3)
280280
pdims = PoolDims(x, 2)
281281
y_hat = pool(x, pdims)
282282
@test ddims(y_hat) == y
283-
@test ddims(∇pool(y_hat, x, pdims)) == dx
283+
@test ddims(∇pool(y_hat, y_hat, x, pdims)) == dx
284284

285285
# Strided pooling
286286
pdims = PoolDims(x, 2; stride=1)
287287
y_hat = pool(x, pdims)
288288
@test ddims(y_hat) == y_nostride
289-
@test ddims(∇pool(y_hat, x, pdims)) == dx_nostride
289+
@test ddims(∇pool(y_hat, y_hat, x, pdims)) == dx_nostride
290290

291291
# Padded pooling
292292
pdims = PoolDims(x, 2; padding=1)
293293
y_hat = pool(x, pdims)
294294
@test ddims(y_hat) == y_pad
295-
@test ddims(∇pool(y_hat, x, pdims)) == dx_pad
295+
@test ddims(∇pool(y_hat, y_hat, x, pdims)) == dx_pad
296296
end
297297
end
298298
end

0 commit comments

Comments
 (0)