Commit a36f15b

Add fold and unfold (#444)
* fold/unfold added
* fold kernel flipping
* docs, fix semicolon error
* unfold flipped=true default, added to docs, rrule test
* doc example fix for julia 1.6 compat.
* removed fold/unfold from export
1 parent f7597d9 commit a36f15b

File tree

5 files changed: +247 −0 lines changed


docs/src/reference.md

Lines changed: 2 additions & 0 deletions
@@ -71,6 +71,8 @@ ConvDims
 depthwiseconv
 DepthwiseConvDims
 DenseConvDims
+unfold
+fold
 ```

 ## Upsampling

src/NNlib.jl

Lines changed: 2 additions & 0 deletions
@@ -61,6 +61,8 @@ export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter,
 include("conv_bias_act.jl")
 export conv_bias_act, conv_bias_act!

+include("fold.jl")
+
 include("ctc.jl")
 export ctc_loss

src/fold.jl

Lines changed: 199 additions & 0 deletions
@@ -0,0 +1,199 @@
"""
    unfold(x, kernel_size; stride = 1, pad = 0, dilation = 1, flipped = true)

Places sliding windows of x into a container tensor of size `(num_windows,
window_size, batchsize)`. The window size is `prod(spatial dims of kernel) *
input_channels`. The number of sliding windows matches that of convolution
(`conv`) with the same `kernel_size` and arguments. Note that by default
`conv` flips the spatial dimensions of its kernel (default `flipped=false`),
whereas `unfold` does not (default `flipped=true`).
Uses `NNlib.im2col!` as backend.

See also [`fold`](@ref), the adjoint/transpose operator
and a potential inverse of `unfold`.

# Example
The example below demonstrates that `unfold` uses the same sliding windows as `conv`.
In general, [`batched_mul`](@ref) + `unfold` should not be used to compute convolution.
```jldoctest
julia> x = reshape([100 2 3 40 5 6 700], 7, 1, 1);  # 1D data, 1 channel, batch of 1

julia> w = reshape([1 0 -1], 3, 1, 1);  # 1D conv kernel of length 3

julia> kws = (pad=1, stride=2, flipped=true);  # use same args for conv and unfold

julia> z = NNlib.unfold(x, size(w); kws...)
4×3×1 Array{Int64, 3}:
[:, :, 1] =
  0  100    2
  2    3   40
 40    5    6
  6  700    0

julia> y1 = conv(x, w; kws...)
4×1×1 Array{Int64, 3}:
[:, :, 1] =
  -2
 -38
  34
   6

julia> y2 = z ⊠ w  # ⊠ (\\boxtimes) is NNlib.batched_mul
4×1×1 Array{Int64, 3}:
[:, :, 1] =
  -2
 -38
  34
   6
```
"""
function unfold(x::AbstractArray{T, N}, kernel_size::NTuple{K}; stride = 1, pad = 0, dilation = 1, flipped = true) where {T, K, N}
    stride = expand(Val(N - 2), stride)
    padding = expand(Val(N - 2), pad)
    dilation = expand(Val(N - 2), dilation)
    cdims = DenseConvDims(size(x), kernel_size; stride, padding, dilation, flipkernel=flipped)
    return unfold(x, cdims)
end
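Editor's aside, not part of the diff: the `(num_windows, window_size, batchsize)` layout generalizes beyond the 1D docstring example. A minimal 2D shape check, assuming only that NNlib is loaded:

using NNlib

x = rand(Float32, 6, 6, 3, 2)  # 6×6 image, 3 channels, batch of 2
w = rand(Float32, 2, 2, 3, 4)  # 2×2 window; unfold ignores the output-channel dim of w

y = NNlib.unfold(x, size(w))   # 5×5 = 25 windows, each of length 2*2*3 = 12
@assert size(y) == (25, 12, 2)

z = NNlib.fold(y, size(x), size(w))  # adjoint: overlapping entries are summed
@assert size(z) == size(x)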
"""
    fold(y, output_size, kernel_size; stride = 1, pad = 0, dilation = 1, flipped = true)

The adjoint/transpose operator of `unfold`. It accumulates sliding windows from
the output of `unfold` into a container tensor of size `output_size`. An inverse
of `unfold` may be obtained (in some cases) by using `fold` and accounting for
scaling issues with a divisor (see example). Uses `NNlib.col2im!` as backend.

See also [`unfold`](@ref).

# Example
```jldoctest
julia> x = reshape([100 2 3 40 5 6 700], 7, 1, 1);  # 1D data, 1 channel, batch of 1

julia> y = NNlib.unfold(x, (3,1,1))  # sliding window of size 3
5×3×1 Array{Int64, 3}:
[:, :, 1] =
 100    2    3
   2    3   40
   3   40    5
  40    5    6
   5    6  700

julia> z = NNlib.fold(y, size(x), (3,1,1))  # sum of contributions in y; 100 appears once, 40 three times
7×1×1 Array{Int64, 3}:
[:, :, 1] =
 100
   4
   9
 120
  15
  12
 700

julia> divisor = NNlib.fold(NNlib.unfold(ones(size(x)...), (3,1,1)), size(x), (3,1,1))
7×1×1 Array{Float64, 3}:
[:, :, 1] =
 1.0
 2.0
 3.0
 3.0
 3.0
 2.0
 1.0

julia> z ./ divisor
7×1×1 Array{Float64, 3}:
[:, :, 1] =
 100.0
   2.0
   3.0
  40.0
   5.0
   6.0
 700.0
```
In general, an inverse of `unfold` does not exist if `divisor` contains zeros.
"""
function fold(x::AbstractArray{T, 3}, output_size::NTuple{N}, kernel_size::NTuple{K}; stride = 1, pad = 0, dilation = 1, flipped = true) where {T, K, N}
    stride = expand(Val(N - 2), stride)
    padding = expand(Val(N - 2), pad)
    dilation = expand(Val(N - 2), dilation)
    cdims = DenseConvDims(output_size, kernel_size; stride, padding, dilation, flipkernel=flipped)
    return fold(x, output_size, cdims)
end
# im2col_dims returns (numblocks, blocksize, threadnum), where the thread dim is used as
# thread-local workspace for multithreaded conv. Ultimately, we want to replace threadnum
# with batchsize.
unfold_dims(cdims::DenseConvDims) = im2col_dims(cdims)[1:2]

# auto-allocating versions
function unfold(x::AbstractArray{T, N}, cdims::DenseConvDims) where {T, N}
    y = similar(x, unfold_dims(cdims)..., size(x, N))  # (numblocks, blocksize, batchsize)
    return unfold!(y, x, cdims)
end

function fold(y::AbstractArray{T, 3}, output_size::NTuple, cdims::DenseConvDims) where {T}
    x = similar(y, output_size)
    return fold!(x, y, cdims)
end
# N < 5 -dimensional in-place versions
function unfold!(y::AbstractArray{yT, 3}, x::AbstractArray{xT, N}, cdims::DenseConvDims) where {yT, xT, N}
    unfold!(
        y,
        insert_singleton_spatial_dimension(x, 5-N),
        insert_singleton_spatial_dimension(cdims, 5-N),
    )
    return y
end

function fold!(x::AbstractArray{xT, N}, y::AbstractArray{yT, 3}, cdims::DenseConvDims) where {yT, xT, N}
    fold!(
        insert_singleton_spatial_dimension(x, 5-N),
        y,
        insert_singleton_spatial_dimension(cdims, 5-N),
    )
    return x
end

# 5-dimensional in-place versions
function unfold!(y::AbstractArray{yT, 3}, x::AbstractArray{xT, 5}, cdims::DenseConvDims) where {yT, xT}
    @threads for batch_idx in 1:size(x, 5)
        y_slice = view(y, :, :, batch_idx)
        im2col!(y_slice, view(x, :, :, :, :, batch_idx), cdims)
    end
    return y
end

function fold!(x::AbstractArray{xT, 5}, y::AbstractArray{yT, 3}, cdims::DenseConvDims) where {xT, yT}
    @threads for batch_idx in 1:size(x, 5)
        y_slice = view(y, :, :, batch_idx)
        col2im!(view(x, :, :, :, :, batch_idx), y_slice, cdims)
    end
    return x
end
# reverse diff rules
function rrule(::typeof(unfold), x, cdims::DenseConvDims; kw...)
    function unfold_pullback(Δ)
        return (
            NoTangent(),
            fold(unthunk(Δ), size(x), cdims; kw...),
            NoTangent(),
        )
    end
    return unfold(x, cdims; kw...), unfold_pullback
end

function rrule(::typeof(fold), x, output_size, cdims::DenseConvDims; kw...)
    function fold_pullback(Δ)
        return (
            NoTangent(),
            unfold(unthunk(Δ), cdims; kw...),
            NoTangent(),
            NoTangent(),
        )
    end
    return fold(x, output_size, cdims; kw...), fold_pullback
end
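Editor's aside, not part of the diff: these rrules encode that `fold` and `unfold` are adjoints of one another, i.e. ⟨y, unfold(x)⟩ = ⟨fold(y), x⟩. A quick numerical check of that identity, assuming only NNlib and LinearAlgebra:

using NNlib, LinearAlgebra

x = rand(8, 8, 2, 1)
cdims = DenseConvDims(size(x), (3, 3, 2, 5); padding=1)

y = rand(size(NNlib.unfold(x, cdims))...)  # arbitrary cotangent of matching shape

# ⟨y, unfold(x)⟩ == ⟨fold(y), x⟩ up to floating-point error
@assert dot(y, NNlib.unfold(x, cdims)) ≈ dot(NNlib.fold(y, size(x), cdims), x)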

test/fold.jl

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
using NNlib, Test

@testset "unfold wrapper" begin
    x = rand(rng, 16, 16, 3, 10)
    w = rand(rng, 5, 5, 3, 2)
    @test size(NNlib.unfold(x, size(w))) == (144, 75, 10)
    @test size(NNlib.unfold(x, size(w); pad=2)) == (256, 75, 10)
    @test size(NNlib.unfold(x, size(w); stride=2)) == (36, 75, 10)
    @test size(NNlib.unfold(x, size(w); dilation=2)) == (64, 75, 10)
end

@testset "Inverses: spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3)
    x = rand(rng, repeat([8], spatial_rank)..., 3, 2)
    w = rand(rng, repeat([3], spatial_rank)..., 3, 3)
    cdims = DenseConvDims(x, w; padding=1)
    y = NNlib.unfold(x, cdims)
    z = NNlib.fold(y, size(x), cdims)
    divisor = NNlib.fold(NNlib.unfold(ones(eltype(x), size(x)...), cdims), size(x), cdims)
    @test isapprox(z ./ divisor, x, rtol=1.0e-7)

    # introduce stride
    cdims = DenseConvDims(x, w; padding=1, stride=2)
    y = NNlib.unfold(x, cdims)
    z = NNlib.fold(y, size(x), cdims)
    divisor = NNlib.fold(NNlib.unfold(ones(eltype(x), size(x)...), cdims), size(x), cdims)
    @test isapprox(z ./ divisor, x, rtol=1.0e-7)
end

@testset "AutoDiff: spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3)
    x = rand(rng, repeat([5], spatial_rank)..., 3, 2)
    w = rand(rng, repeat([3], spatial_rank)..., 3, 3)
    cdims = DenseConvDims(x, w)
    gradtest(x -> NNlib.unfold(x, cdims), x)
    test_rrule(NNlib.unfold, x, cdims)

    y = NNlib.unfold(x, cdims)
    gradtest(y -> NNlib.fold(y, size(x), cdims), y)
    test_rrule(NNlib.fold, y, size(x), cdims)
end
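Editor's aside, not part of the diff: the divisor trick used in the docstring and in the inverse tests above could be wrapped in a small helper. A sketch under those assumptions (`fold_mean` is a hypothetical name, not an NNlib API):

using NNlib

# Hypothetical helper: average overlapping contributions instead of summing them,
# giving a (partial) inverse of unfold wherever the divisor is nonzero.
function fold_mean(y, output_size, cdims::DenseConvDims)
    divisor = NNlib.fold(NNlib.unfold(ones(output_size), cdims), output_size, cdims)
    return NNlib.fold(y, output_size, cdims) ./ divisor
end

x = rand(8, 8, 3, 2)
cdims = DenseConvDims(size(x), (3, 3, 3, 1); padding=1)
@assert fold_mean(NNlib.unfold(x, cdims), size(x), cdims) ≈ x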

test/runtests.jl

Lines changed: 4 additions & 0 deletions
@@ -52,6 +52,10 @@ include("test_utils.jl")
     include("ctc.jl")
 end

+@testset "Fold/Unfold" begin
+    include("fold.jl")
+end
+
 @testset "Inference" begin
     include("inference.jl")
 end
