Skip to content

Commit d4cbdb3

Browse files
committed
fix maxpool gradient
1 parent c8b7661 commit d4cbdb3

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/impl/pooling_direct.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -176,15 +176,15 @@ for name in (:max, :mean)
176176
# If it's equal; this is the one we chose. We only choose one per
177177
# kernel window, all other elements of dx must be zero.
178178
if y_idx == x[x_idxs...] && !maxpool_already_chose
179-
dx[x_idxs...] = dy_idx*alpha + beta*dx[x_idxs...]
179+
dx[x_idxs...] += dy_idx*alpha + beta*dx[x_idxs...]
180180
maxpool_already_chose = true
181181
# Maxpooling does not support `beta` right now. :(
182182
#else
183183
# dx[x_idxs...] = T(0) + beta*dx[x_idxs...]
184184
end
185185
elseif $(name == :mean)
186186
# Either does meanpool :(
187-
dx[x_idxs...] = dy_idx*alpha + dx[x_idxs...]
187+
dx[x_idxs...] += dy_idx*alpha + dx[x_idxs...]
188188
else
189189
error("Unimplemented codegen path")
190190
end
@@ -228,7 +228,7 @@ for name in (:max, :mean)
228228
x_idxs = (input_kw, input_kh, input_kd, c, batch_idx)
229229
if $(name == :max)
230230
if y_idx == x[x_idxs...] && !maxpool_already_chose
231-
dx[x_idxs...] = dy_idx*alpha + beta*dx[x_idxs...]
231+
dx[x_idxs...] += dy_idx*alpha + beta*dx[x_idxs...]
232232
maxpool_already_chose = true
233233
#else
234234
# dx[x_idxs...] = T(0) + beta*dx[x_idxs...]

0 commit comments

Comments
 (0)