11using NNlib: pad_constant, pad_repeat, pad_zeros, pad_reflect, pad_symmetric, pad_circular
22
33@testset " padding constant" begin
4- x = rand (2 , 2 , 2 )
5-
4+ x = rand (2 , 2 , 2 )
5+
66 p = NNlib. gen_pad ((1 ,2 ,3 ,4 ,5 ,6 ), (1 ,2 ,3 ), 4 )
77 @test p == ((1 , 2 ), (3 , 4 ), (5 , 6 ), (0 , 0 ))
8-
8+
99 @test_throws ArgumentError NNlib. gen_pad ((1 ,2 ,3 ,4 ,5 ,), (1 ,2 ,3 ), 4 )
10-
10+
1111 p = NNlib. gen_pad ((1 ,3 ), (1 ,3 ), 4 )
1212 @test p == ((1 , 1 ), (0 , 0 ), (3 , 3 ), (0 , 0 ))
13-
13+
1414 p = NNlib. gen_pad (1 , (1 ,2 ,3 ), 4 )
1515 @test p == ((1 , 1 ), (1 , 1 ), (1 , 1 ), (0 , 0 ))
16-
16+
1717 p = NNlib. gen_pad (3 , :, 2 )
1818 @test p == ((3 , 3 ), (3 , 3 ))
1919
2020 p = NNlib. gen_pad ((1 ,0 ), 1 , 2 )
2121 @test p == ((1 ,0 ), (0 ,0 ))
22-
22+
2323 y = pad_constant (x, (3 , 2 , 4 ))
2424 @test size (y) == (8 , 6 , 10 )
2525 @test y[4 : 5 , 3 : 4 , 5 : 6 ] ≈ x
2626 y[4 : 5 , 3 : 4 , 5 : 6 ] .= 0
2727 @test all (y .== 0 )
28-
28+
2929 @test pad_constant (x, (3 , 2 , 4 )) ≈ pad_zeros (x, (3 , 2 , 4 ))
30- @test pad_zeros (x, 2 ) ≈ pad_zeros (x, (2 ,2 ,2 ))
31-
30+ @test pad_zeros (x, 2 ) ≈ pad_zeros (x, (2 ,2 ,2 ))
31+
3232 y = pad_constant (x, (3 , 2 , 4 , 5 ), 1.2 , dims = (1 ,3 ))
3333 @test size (y) == (7 , 2 , 11 )
3434 @test y[4 : 5 , 1 : 2 , 5 : 6 ] ≈ x
3535 y[4 : 5 , 1 : 2 , 5 : 6 ] .= 1.2
3636 @test all (y .== 1.2 )
37-
37+
3838 @test pad_constant (x, (2 ,2 ,2 ,2 ), 1.2 , dims = (1 ,3 )) ≈
3939 pad_constant (x, 2 , 1.2 , dims = (1 ,3 ))
40-
40+
4141 @test pad_constant (x, 1 , dims = 1 : 2 ) ==
42- pad_constant (x, 1 , dims = (1 ,2 ))
43-
42+ pad_constant (x, 1 , dims = (1 ,2 ))
43+
4444 @test size (pad_constant (x, 1 , dims = 1 )) == (4 ,2 ,2 )
45-
45+
4646 @test all (pad_zeros (randn (2 ), (1 , 2 ))[[1 , 4 , 5 ]] .== 0 )
47-
47+
4848 gradtest (x -> pad_constant (x, 2 ), rand (2 ,2 ,2 ))
4949 gradtest (x -> pad_constant (x, (2 , 1 , 1 , 2 )), rand (2 ,2 ))
5050 gradtest (x -> pad_constant (x, (2 , 1 ,)), rand (2 ))
5151end
5252
5353@testset " padding repeat" begin
54- x = rand (2 , 2 , 2 )
55-
54+ x = rand (2 , 2 , 2 )
55+
5656 # y = @inferred pad_repeat(x, (3, 2, 4, 5))
5757 y = pad_repeat (x, (3 , 2 , 4 , 5 ))
5858 @test size (y) == (7 , 11 , 2 )
5959 @test y[4 : 5 , 5 : 6 , :] ≈ x
60-
60+
6161 # y = @inferred pad_repeat(x, (3, 2, 4, 5), dims=(1,3))
6262 y = pad_repeat (x, (3 , 2 , 4 , 5 ), dims= (1 ,3 ))
6363 @test size (y) == (7 , 2 , 11 )
6464 @test y[4 : 5 , :, 5 : 6 ] ≈ x
65-
65+
6666 @test pad_repeat (reshape (1 : 9 , 3 , 3 ), (1 ,2 )) ==
6767 [1 4 7
6868 1 4 7
6969 2 5 8
7070 3 6 9
7171 3 6 9
7272 3 6 9 ]
73-
73+
7474 @test pad_repeat (reshape (1 : 9 , 3 , 3 ), (2 ,2 ), dims= 2 ) ==
7575 [1 1 1 4 7 7 7
7676 2 2 2 5 8 8 8
7777 3 3 3 6 9 9 9 ]
78-
78+
7979 @test pad_repeat (x, (2 , 2 , 2 , 2 ), dims= (1 ,3 )) ≈
8080 pad_repeat (x, 2 , dims= (1 ,3 ))
81-
81+
8282 gradtest (x -> pad_repeat (x, (2 ,2 ,2 ,2 )), rand (2 ,2 ,2 ))
8383end
8484
8787 @test y == [7 4 1 4 7 4 1
8888 8 5 2 5 8 5 2
8989 9 6 3 6 9 6 3 ]
90-
90+
9191 y = pad_reflect (reshape (1 : 9 , 3 , 3 ), (2 ,2 ,2 ,2 ))
9292 @test y == [9 6 3 6 9 6 3
9393 8 5 2 5 8 5 2
9696 9 6 3 6 9 6 3
9797 8 5 2 5 8 5 2
9898 7 4 1 4 7 4 1 ]
99-
100- x = rand (4 , 4 , 4 )
99+
100+ x = rand (4 , 4 , 4 )
101101 @test pad_reflect (x, (2 , 2 , 2 , 2 ), dims= (1 ,3 )) ≈
102102 pad_reflect (x, 2 , dims= (1 ,3 ))
103-
104- # pad_reflect needs larger test input as padding must
103+
104+ # pad_reflect needs larger test input as padding must
105105 # be strictly less than array size in that dimension
106106 gradtest (x -> pad_reflect (x, (2 ,2 ,2 ,2 )), rand (3 ,3 ,3 ))
107+
108+ x = reshape (1 : 9 , 3 , 3 , 1 , 1 )
109+ @test NNlib. pad_reflect (x, (1 , 0 , 1 , 0 ); dims= 1 : 2 ) == [
110+ 5 2 5 8 ;
111+ 4 1 4 7 ;
112+ 5 2 5 8 ;
113+ 6 3 6 9 ;;;;]
114+ @test NNlib. pad_reflect (x, (0 , 1 , 0 , 1 ); dims= 1 : 2 ) == [
115+ 1 4 7 4 ;
116+ 2 5 8 5 ;
117+ 3 6 9 6 ;
118+ 2 5 8 5 ;;;;]
107119end
108120
109121@testset " padding symmetric" begin
110122 y = pad_symmetric (reshape (1 : 9 , 3 , 3 ), (2 ,2 ), dims= 2 )
111123 @test y == [4 1 1 4 7 7 4
112124 5 2 2 5 8 8 5
113125 6 3 3 6 9 9 6 ]
114-
126+
115127 y = pad_symmetric (reshape (1 : 9 , 3 , 3 ), (2 ,2 ,2 ,2 ))
116128 @test y == [5 2 2 5 8 8 5
117129 4 1 1 4 7 7 4
@@ -120,20 +132,32 @@ end
120132 6 3 3 6 9 9 6
121133 6 3 3 6 9 9 6
122134 5 2 2 5 8 8 5 ]
123-
124- x = rand (4 , 4 , 4 )
135+
136+ x = rand (4 , 4 , 4 )
125137 @test pad_symmetric (x, (2 , 2 , 2 , 2 ), dims= (1 ,3 )) ≈
126138 pad_symmetric (x, 2 , dims= (1 ,3 ))
127-
139+
128140 gradtest (x -> pad_symmetric (x, (2 ,2 ,2 ,2 )), rand (2 ,2 ,2 ))
141+
142+ x = reshape (1 : 9 , 3 , 3 , 1 , 1 )
143+ @test NNlib. pad_symmetric (x, (1 , 0 , 1 , 0 ); dims= 1 : 2 ) == [
144+ 1 1 4 7 ;
145+ 1 1 4 7 ;
146+ 2 2 5 8 ;
147+ 3 3 6 9 ;;;;]
148+ @test NNlib. pad_symmetric (x, (0 , 1 , 0 , 1 ); dims= 1 : 2 ) == [
149+ 1 4 7 7 ;
150+ 2 5 8 8 ;
151+ 3 6 9 9 ;
152+ 3 6 9 9 ;;;;]
129153end
130154
131155@testset " padding circular" begin
132156 y = pad_circular (reshape (1 : 9 , 3 , 3 ), (2 ,2 ), dims= 2 )
133157 @test y == [4 7 1 4 7 1 4
134158 5 8 2 5 8 2 5
135159 6 9 3 6 9 3 6 ]
136-
160+
137161 y = pad_circular (reshape (1 : 9 , 3 , 3 ), (2 ,2 ,2 ,2 ))
138162 @test y == [5 8 2 5 8 2 5
139163 6 9 3 6 9 3 6
@@ -142,10 +166,10 @@ end
142166 6 9 3 6 9 3 6
143167 4 7 1 4 7 1 4
144168 5 8 2 5 8 2 5 ]
145-
146- x = rand (4 , 4 , 4 )
169+
170+ x = rand (4 , 4 , 4 )
147171 @test pad_circular (x, (2 , 2 , 2 , 2 ), dims= (1 ,3 )) ≈
148172 pad_circular (x, 2 , dims= (1 ,3 ))
149-
173+
150174 gradtest (x -> pad_circular (x, (2 ,2 ,2 ,2 )), rand (2 ,2 ,2 ))
151175end
0 commit comments