Skip to content

Commit 8b739fd

Browse files
committed
add a finite difference test
1 parent 924f472 commit 8b739fd

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

test/pooling.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,3 +316,33 @@ end
316316
dx = ∇maxpool(dy, y, x, pdims)
317317
@test dx[:,:,1,1] == [1.0 0.0 1.0; 0.0 0.0 0.0; 1.0 0.0 1.0]
318318
end
319+
320+
# test "true" strided case, see https://github.com/FluxML/NNlib.jl/issues/205
321+
322+
@testset "Issus #205" begin
323+
x = [
324+
0.0299635 0.233456 0.596161 0.161514 0.0094027;
325+
0.389984 0.235158 0.579525 0.301893 0.561358;
326+
0.0830242 0.483759 0.914904 0.253871 0.820061;
327+
0.425287 0.53451 0.0405225 0.729861 0.403925;
328+
0.473724 0.571418 0.558427 0.552183 0.561624;
329+
]
330+
331+
dx_ans = [
332+
0.0 0.0 2.0 0.0 0.0;
333+
1.0 0.0 0.0 0.0 1.0;
334+
0.0 1.0 4.0 0.0 2.0;
335+
0.0 1.0 0.0 2.0 0.0;
336+
0.0 2.0 0.0 0.0 0.0;
337+
]
338+
339+
x = reshape(x, 5, 5, 1, 1)
340+
dx_ans = reshape(dx_ans, 5, 5, 1, 1)
341+
dy = ones(4,4,1,1)
342+
pdims = PoolDims(x, 2; stride=1, padding=0)
343+
344+
y = maxpool(x, (2,2), pad=0, stride=1)
345+
dx = ∇maxpool(dy, y, x, pdims)
346+
347+
@test dx dx_ans
348+
end

0 commit comments

Comments
 (0)