Skip to content

Commit 13bce2f

Browse files
committed
symmetric and circular padding
1 parent 1672035 commit 13bce2f

File tree

4 files changed

+175
-6
lines changed

4 files changed

+175
-6
lines changed

docs/src/reference.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,10 @@ lpnormpool
5757

5858
```@docs
5959
pad_reflect
60-
pad_constant
60+
pad_symmetric
61+
pad_circular
6162
pad_repeat
63+
pad_constant
6264
pad_zeros
6365
```
6466

src/NNlib.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ export maxpool, maxpool!, meanpool, meanpool!, lpnormpool, lpnormpool!,
7575
∇maxpool, ∇maxpool!, ∇meanpool, ∇meanpool!, ∇lpnormpool, ∇lpnormpool!
7676

7777
include("padding.jl")
78-
export pad_constant, pad_repeat, pad_reflect, pad_zeros
78+
export pad_constant, pad_repeat, pad_reflect, pad_zeros, pad_symmetric, pad_circular
7979

8080
include("upsample.jl")
8181
export upsample_nearest, ∇upsample_nearest,

src/padding.jl

Lines changed: 122 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ If `dims` is not given, it defaults to all dimensions.
2323
For integer `pad` input, it is applied on both sides
2424
on every dimension in `dims`.
2525
26-
See also [`pad_zeros`](@ref), [`pad_reflect`](@ref) and [`pad_repeat`](@ref).
26+
See also [`pad_zeros`](@ref), [`pad_repeat`](@ref), [`pad_reflect`](@ref), [`pad_symmetric`](@ref), and [`pad_circular`](@ref).
2727
2828
```jldoctest
2929
julia> r = reshape(1:4, 2, 2)
@@ -174,7 +174,7 @@ on every dimension in `dims`. In this case, `dims`
174174
defaults to the first `ndims(x)-2` dimensions
175175
(i.e. excludes the channel and batch dimension).
176176
177-
See also [`pad_reflect`](@ref) and [`pad_constant`](@ref).
177+
See also [`pad_reflect`](@ref), [`pad_symmetric`](@ref), [`pad_circular`](@ref), and [`pad_constant`](@ref).
178178
179179
```jldoctest
180180
julia> r = reshape(1:9, 3, 3)
@@ -235,7 +235,7 @@ on every dimension in `dims`. In this case, `dims`
235235
defaults to the first `ndims(x)-2` dimensions
236236
(i.e. excludes the channel and batch dimension).
237237
238-
See also [`pad_repeat`](@ref) and [`pad_constant`](@ref).
238+
See also [`pad_repeat`](@ref), [`pad_symmetric`](@ref), [`pad_circular`](@ref), and [`pad_constant`](@ref).
239239
240240
```jldoctest
241241
julia> r = reshape(1:9, 3, 3)
@@ -277,8 +277,127 @@ function pad_reflect(x::AbstractArray{F,N}, pad::NTuple{2,Int};
277277
return cat(xl, x, xr, dims = dims)
278278
end
279279

280+
"""
281+
pad_symmetric(x, pad::Tuple; [dims])
282+
pad_symmetric(x, pad::Int; [dims])
283+
284+
Pad the array `x` reflecting its values symmetrically across the border, i.e. the border values of `x` are present in the padding values, in contrast to [`pad_reflect`](@ref).
285+
286+
`pad` can a tuple of integers `(l1, r1, ..., ln, rn)`
287+
of some length `2n` that specifies the left and right padding size
288+
for each of the dimensions in `dims`. If `dims` is not given,
289+
it defaults to the first `n` dimensions.
290+
291+
For integer `pad` input instead, it is applied on both sides
292+
on every dimension in `dims`. In this case, `dims`
293+
defaults to the first `ndims(x)-2` dimensions
294+
(i.e. excludes the channel and batch dimension).
295+
296+
See also [`pad_repeat`](@ref), [`pad_reflect`](@ref), [`pad_circular`](@ref), and [`pad_constant`](@ref).
297+
298+
```jldoctest
299+
julia> r = reshape(1:9, 3, 3)
300+
3×3 reshape(::UnitRange{Int64}, 3, 3) with eltype Int64:
301+
1 4 7
302+
2 5 8
303+
3 6 9
304+
305+
julia> NNlib.pad_symmetric(r, (1,2,1,2))
306+
6×6 Matrix{Int64}:
307+
1 1 4 7 7 4
308+
1 1 4 7 7 4
309+
2 2 5 8 8 5
310+
3 3 6 9 9 6
311+
3 3 6 9 9 6
312+
2 2 5 8 8 5
313+
```
314+
"""
315+
function pad_symmetric(x::AbstractArray, pad::NTuple{M,Int};
316+
dims=1:M÷2) where M
317+
length(dims) == M ÷ 2 ||
318+
throw(ArgumentError("The number of dims should be equal to the number of padding dimensions"))
319+
for (i, d) in enumerate(dims)
320+
x = pad_symmetric(x, (pad[2i-1], pad[2i]); dims = d)
321+
end
322+
return x
323+
end
324+
325+
function pad_symmetric(x::AbstractArray{F,N}, pad::NTuple{2,Int};
326+
dims::Int = 1) where {F,N}
327+
lpad, rpad = pad
328+
329+
n = size(x, dims)
330+
xl = selectdim(x, dims, lpad:-1:1)
331+
xr = selectdim(x, dims, n:-1:n-rpad+1)
332+
return cat(xl, x, xr, dims = dims)
333+
end
334+
335+
"""
336+
pad_circular(x, pad::Tuple; [dims])
337+
pad_circular(x, pad::Int; [dims])
338+
339+
Pad the array `x` "circularly" across the border by wrapping around values from the opposite side of `x`.
340+
341+
`pad` can a tuple of integers `(l1, r1, ..., ln, rn)`
342+
of some length `2n` that specifies the left and right padding size
343+
for each of the dimensions in `dims`. If `dims` is not given,
344+
it defaults to the first `n` dimensions.
345+
346+
For integer `pad` input instead, it is applied on both sides
347+
on every dimension in `dims`. In this case, `dims`
348+
defaults to the first `ndims(x)-2` dimensions
349+
(i.e. excludes the channel and batch dimension).
350+
351+
The pad length in any dimension must not exceed the
352+
size of `x` in that dimension.
353+
354+
See also [`pad_repeat`](@ref), [`pad_reflect`](@ref), [`pad_symmetric`](@ref), and [`pad_constant`](@ref).
355+
356+
```jldoctest
357+
julia> r = reshape(1:9, 3, 3)
358+
3×3 reshape(::UnitRange{Int64}, 3, 3) with eltype Int64:
359+
1 4 7
360+
2 5 8
361+
3 6 9
362+
363+
julia> NNlib.pad_circular(r, (1,2,1,2))
364+
6×6 Matrix{Int64}:
365+
9 3 6 9 3 6
366+
7 1 4 7 1 4
367+
8 2 5 8 2 5
368+
9 3 6 9 3 6
369+
7 1 4 7 1 4
370+
8 2 5 8 2 5
371+
```
372+
"""
373+
function pad_circular(x::AbstractArray, pad::NTuple{M,Int};
374+
dims=1:M÷2) where M
375+
length(dims) == M ÷ 2 ||
376+
throw(ArgumentError("The number of dims should be equal to the number of padding dimensions"))
377+
378+
for (i, d) in enumerate(dims)
379+
x = pad_circular(x, (pad[2i-1], pad[2i]); dims = d)
380+
end
381+
return x
382+
end
383+
384+
function pad_circular(x::AbstractArray{F,N}, pad::NTuple{2,Int};
385+
dims::Int = 1) where {F,N}
386+
lpad, rpad = pad
387+
n = size(x, dims)
388+
389+
xl = selectdim(x, dims, n-lpad+1:n)
390+
xr = selectdim(x, dims, 1:rpad)
391+
return cat(xl, x, xr, dims = dims)
392+
end
393+
280394
# convenience methods for symmetric and homogeneous padding
281395
pad_repeat(x::AbstractArray{F,N}, pad::Int; dims=1:N-2) where {F,N} =
282396
pad_repeat(x, ntuple(_ -> pad, 2length(dims)); dims = dims)
283397
pad_reflect(x::AbstractArray{F,N}, pad::Int; dims=1:N-2) where {F,N} =
284398
pad_reflect(x, ntuple(_ -> pad, 2length(dims)); dims = dims)
399+
pad_symmetric(x::AbstractArray{F,N}, pad::Int; dims=1:N-2) where {F,N} =
400+
pad_symmetric(x, ntuple(_ -> pad, 2length(dims)); dims = dims)
401+
pad_circular(x::AbstractArray{F,N}, pad::Int; dims=1:N-2) where {F,N} =
402+
pad_circular(x, ntuple(_ -> pad, 2length(dims)); dims = dims)
403+

test/padding.jl

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,5 +99,53 @@ end
9999
@test pad_reflect(x, (2, 2, 2, 2), dims=(1,3))
100100
pad_reflect(x, 2, dims=(1,3))
101101

102-
gradtest(x -> pad_repeat(x, (2,2,2,2)), rand(2,2,2))
102+
# pad_reflect needs larger test input as padding must
103+
# be strictly less than array size in that dimension
104+
gradtest(x -> pad_reflect(x, (2,2,2,2)), rand(3,3,3))
105+
end
106+
107+
@testset "padding symmetric" begin
108+
y = pad_symmetric(reshape(1:9, 3, 3), (2,2), dims=2)
109+
@test y == [ 4 1 1 4 7 7 4
110+
5 2 2 5 8 8 5
111+
6 3 3 6 9 9 6]
112+
113+
y = pad_symmetric(reshape(1:9, 3, 3), (2,2,2,2))
114+
@test y == [5 2 2 5 8 8 5
115+
4 1 1 4 7 7 4
116+
4 1 1 4 7 7 4
117+
5 2 2 5 8 8 5
118+
6 3 3 6 9 9 6
119+
6 3 3 6 9 9 6
120+
5 2 2 5 8 8 5]
121+
122+
x = rand(4, 4, 4)
123+
124+
@test pad_symmetric(x, (2, 2, 2, 2), dims=(1,3))
125+
pad_symmetric(x, 2, dims=(1,3))
126+
127+
gradtest(x -> pad_symmetric(x, (2,2,2,2)), rand(2,2,2))
128+
end
129+
130+
@testset "padding circular" begin
131+
y = pad_circular(reshape(1:9, 3, 3), (2,2), dims=2)
132+
@test y == [ 4 7 1 4 7 1 4
133+
5 8 2 5 8 2 5
134+
6 9 3 6 9 3 6]
135+
136+
y = pad_circular(reshape(1:9, 3, 3), (2,2,2,2))
137+
@test y == [5 8 2 5 8 2 5
138+
6 9 3 6 9 3 6
139+
4 7 1 4 7 1 4
140+
5 8 2 5 8 2 5
141+
6 9 3 6 9 3 6
142+
4 7 1 4 7 1 4
143+
5 8 2 5 8 2 5]
144+
145+
x = rand(4, 4, 4)
146+
147+
@test pad_circular(x, (2, 2, 2, 2), dims=(1,3))
148+
pad_circular(x, 2, dims=(1,3))
149+
150+
gradtest(x -> pad_circular(x, (2,2,2,2)), rand(2,2,2))
103151
end

0 commit comments

Comments
 (0)