Skip to content

Commit 9e95671

Browse files
authored
Fix non-symmetric padding (#595)
1 parent f87cf6e commit 9e95671

File tree

3 files changed

+74
-42
lines changed

3 files changed

+74
-42
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1717
[weakdeps]
1818
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
1919
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
20-
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
2120
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
2221
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
22+
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
2323

2424
[extensions]
2525
NNlibAMDGPUExt = "AMDGPU"
@@ -33,7 +33,6 @@ AMDGPU = "0.9.4"
3333
Adapt = "3.2, 4"
3434
Atomix = "0.1"
3535
CUDA = "4, 5"
36-
cuDNN = "1"
3736
ChainRulesCore = "1.13"
3837
EnzymeCore = "0.5, 0.6, 0.7"
3938
FFTW = "1.8.0"
@@ -44,4 +43,5 @@ Pkg = "<0.0.1, 1"
4443
Random = "<0.0.1, 1"
4544
Requires = "1.0"
4645
Statistics = "1"
46+
cuDNN = "1"
4747
julia = "1.9"

src/padding.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -270,8 +270,12 @@ function pad_reflect(
270270
) where {F,N}
271271
lpad, rpad = pad
272272
n = size(x, dims)
273-
xl = lpad == 0 ? similar(x, 0) : reverse(selectdim(x, dims, 2:lpad+1); dims)
274-
xr = rpad == 0 ? similar(x, 0) : reverse(selectdim(x, dims, n-rpad:n-1); dims)
273+
xl = lpad == 0 ?
274+
similar(x, ntuple(i -> i == dims ? 0 : size(x, i), ndims(x))) :
275+
reverse(selectdim(x, dims, 2:lpad+1); dims)
276+
xr = rpad == 0 ?
277+
similar(x, ntuple(i -> i == dims ? 0 : size(x, i), ndims(x))) :
278+
reverse(selectdim(x, dims, n-rpad:n-1); dims)
275279
return cat(xl, x, xr; dims)
276280
end
277281

@@ -326,8 +330,12 @@ function pad_symmetric(
326330
lpad, rpad = pad
327331
n = size(x, dims)
328332

329-
xl = lpad == 0 ? similar(x, 0) : reverse(selectdim(x, dims, 1:lpad); dims)
330-
xr = rpad == 0 ? similar(x, 0) : reverse(selectdim(x, dims, n-rpad+1:n); dims)
333+
xl = lpad == 0 ?
334+
similar(x, ntuple(i -> i == dims ? 0 : size(x, i), ndims(x))) :
335+
reverse(selectdim(x, dims, 1:lpad); dims)
336+
xr = rpad == 0 ?
337+
similar(x, ntuple(i -> i == dims ? 0 : size(x, i), ndims(x))) :
338+
reverse(selectdim(x, dims, n-rpad+1:n); dims)
331339
return cat(xl, x, xr; dims)
332340
end
333341

test/padding.jl

Lines changed: 60 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,84 +1,84 @@
11
using NNlib: pad_constant, pad_repeat, pad_zeros, pad_reflect, pad_symmetric, pad_circular
22

33
@testset "padding constant" begin
4-
x = rand(2, 2, 2)
5-
4+
x = rand(2, 2, 2)
5+
66
p = NNlib.gen_pad((1,2,3,4,5,6), (1,2,3), 4)
77
@test p == ((1, 2), (3, 4), (5, 6), (0, 0))
8-
8+
99
@test_throws ArgumentError NNlib.gen_pad((1,2,3,4,5,), (1,2,3), 4)
10-
10+
1111
p = NNlib.gen_pad((1,3), (1,3), 4)
1212
@test p == ((1, 1), (0, 0), (3, 3), (0, 0))
13-
13+
1414
p = NNlib.gen_pad(1, (1,2,3), 4)
1515
@test p == ((1, 1), (1, 1), (1, 1), (0, 0))
16-
16+
1717
p = NNlib.gen_pad(3, :, 2)
1818
@test p == ((3, 3), (3, 3))
1919

2020
p = NNlib.gen_pad((1,0), 1, 2)
2121
@test p == ((1,0), (0,0))
22-
22+
2323
y = pad_constant(x, (3, 2, 4))
2424
@test size(y) == (8, 6, 10)
2525
@test y[4:5, 3:4, 5:6] x
2626
y[4:5, 3:4, 5:6] .= 0
2727
@test all(y .== 0)
28-
28+
2929
@test pad_constant(x, (3, 2, 4)) pad_zeros(x, (3, 2, 4))
30-
@test pad_zeros(x, 2) pad_zeros(x, (2,2,2))
31-
30+
@test pad_zeros(x, 2) pad_zeros(x, (2,2,2))
31+
3232
y = pad_constant(x, (3, 2, 4, 5), 1.2, dims = (1,3))
3333
@test size(y) == (7, 2, 11)
3434
@test y[4:5, 1:2, 5:6] x
3535
y[4:5, 1:2, 5:6] .= 1.2
3636
@test all(y .== 1.2)
37-
37+
3838
@test pad_constant(x, (2,2,2,2), 1.2, dims = (1,3))
3939
pad_constant(x, 2, 1.2, dims = (1,3))
40-
40+
4141
@test pad_constant(x, 1, dims = 1:2) ==
42-
pad_constant(x, 1, dims = (1,2))
43-
42+
pad_constant(x, 1, dims = (1,2))
43+
4444
@test size(pad_constant(x, 1, dims = 1)) == (4,2,2)
45-
45+
4646
@test all(pad_zeros(randn(2), (1, 2))[[1, 4, 5]] .== 0)
47-
47+
4848
gradtest(x -> pad_constant(x, 2), rand(2,2,2))
4949
gradtest(x -> pad_constant(x, (2, 1, 1, 2)), rand(2,2))
5050
gradtest(x -> pad_constant(x, (2, 1,)), rand(2))
5151
end
5252

5353
@testset "padding repeat" begin
54-
x = rand(2, 2, 2)
55-
54+
x = rand(2, 2, 2)
55+
5656
# y = @inferred pad_repeat(x, (3, 2, 4, 5))
5757
y = pad_repeat(x, (3, 2, 4, 5))
5858
@test size(y) == (7, 11, 2)
5959
@test y[4:5, 5:6, :] x
60-
60+
6161
# y = @inferred pad_repeat(x, (3, 2, 4, 5), dims=(1,3))
6262
y = pad_repeat(x, (3, 2, 4, 5), dims=(1,3))
6363
@test size(y) == (7, 2, 11)
6464
@test y[4:5, :, 5:6] x
65-
65+
6666
@test pad_repeat(reshape(1:9, 3, 3), (1,2)) ==
6767
[1 4 7
6868
1 4 7
6969
2 5 8
7070
3 6 9
7171
3 6 9
7272
3 6 9]
73-
73+
7474
@test pad_repeat(reshape(1:9, 3, 3), (2,2), dims=2) ==
7575
[1 1 1 4 7 7 7
7676
2 2 2 5 8 8 8
7777
3 3 3 6 9 9 9]
78-
78+
7979
@test pad_repeat(x, (2, 2, 2, 2), dims=(1,3))
8080
pad_repeat(x, 2, dims=(1,3))
81-
81+
8282
gradtest(x -> pad_repeat(x, (2,2,2,2)), rand(2,2,2))
8383
end
8484

@@ -87,7 +87,7 @@ end
8787
@test y == [7 4 1 4 7 4 1
8888
8 5 2 5 8 5 2
8989
9 6 3 6 9 6 3]
90-
90+
9191
y = pad_reflect(reshape(1:9, 3, 3), (2,2,2,2))
9292
@test y == [9 6 3 6 9 6 3
9393
8 5 2 5 8 5 2
@@ -96,22 +96,34 @@ end
9696
9 6 3 6 9 6 3
9797
8 5 2 5 8 5 2
9898
7 4 1 4 7 4 1]
99-
100-
x = rand(4, 4, 4)
99+
100+
x = rand(4, 4, 4)
101101
@test pad_reflect(x, (2, 2, 2, 2), dims=(1,3))
102102
pad_reflect(x, 2, dims=(1,3))
103-
104-
# pad_reflect needs larger test input as padding must
103+
104+
# pad_reflect needs larger test input as padding must
105105
# be strictly less than array size in that dimension
106106
gradtest(x -> pad_reflect(x, (2,2,2,2)), rand(3,3,3))
107+
108+
x = reshape(1:9, 3, 3, 1, 1)
109+
@test NNlib.pad_reflect(x, (1, 0, 1, 0); dims=1:2) == [
110+
5 2 5 8;
111+
4 1 4 7;
112+
5 2 5 8;
113+
6 3 6 9;;;;]
114+
@test NNlib.pad_reflect(x, (0, 1, 0, 1); dims=1:2) == [
115+
1 4 7 4;
116+
2 5 8 5;
117+
3 6 9 6;
118+
2 5 8 5;;;;]
107119
end
108120

109121
@testset "padding symmetric" begin
110122
y = pad_symmetric(reshape(1:9, 3, 3), (2,2), dims=2)
111123
@test y == [4 1 1 4 7 7 4
112124
5 2 2 5 8 8 5
113125
6 3 3 6 9 9 6]
114-
126+
115127
y = pad_symmetric(reshape(1:9, 3, 3), (2,2,2,2))
116128
@test y == [5 2 2 5 8 8 5
117129
4 1 1 4 7 7 4
@@ -120,20 +132,32 @@ end
120132
6 3 3 6 9 9 6
121133
6 3 3 6 9 9 6
122134
5 2 2 5 8 8 5]
123-
124-
x = rand(4, 4, 4)
135+
136+
x = rand(4, 4, 4)
125137
@test pad_symmetric(x, (2, 2, 2, 2), dims=(1,3))
126138
pad_symmetric(x, 2, dims=(1,3))
127-
139+
128140
gradtest(x -> pad_symmetric(x, (2,2,2,2)), rand(2,2,2))
141+
142+
x = reshape(1:9, 3, 3, 1, 1)
143+
@test NNlib.pad_symmetric(x, (1, 0, 1, 0); dims=1:2) == [
144+
1 1 4 7;
145+
1 1 4 7;
146+
2 2 5 8;
147+
3 3 6 9;;;;]
148+
@test NNlib.pad_symmetric(x, (0, 1, 0, 1); dims=1:2) == [
149+
1 4 7 7;
150+
2 5 8 8;
151+
3 6 9 9;
152+
3 6 9 9;;;;]
129153
end
130154

131155
@testset "padding circular" begin
132156
y = pad_circular(reshape(1:9, 3, 3), (2,2), dims=2)
133157
@test y == [4 7 1 4 7 1 4
134158
5 8 2 5 8 2 5
135159
6 9 3 6 9 3 6]
136-
160+
137161
y = pad_circular(reshape(1:9, 3, 3), (2,2,2,2))
138162
@test y == [5 8 2 5 8 2 5
139163
6 9 3 6 9 3 6
@@ -142,10 +166,10 @@ end
142166
6 9 3 6 9 3 6
143167
4 7 1 4 7 1 4
144168
5 8 2 5 8 2 5]
145-
146-
x = rand(4, 4, 4)
169+
170+
x = rand(4, 4, 4)
147171
@test pad_circular(x, (2, 2, 2, 2), dims=(1,3))
148172
pad_circular(x, 2, dims=(1,3))
149-
173+
150174
gradtest(x -> pad_circular(x, (2,2,2,2)), rand(2,2,2))
151175
end

0 commit comments

Comments
 (0)