1-
"""
    unfold(x, kernel_size; stride = 1, pad = 0, dilation = 1, flipped = true)
43
@@ -7,10 +6,10 @@ window_size, batchsize)`. The window size is determined by the `prod(spatial dim
76of kernel)*input_channels`. The number of sliding windows will match those of
87convolution (`conv`) with the same kernel_size and arguments. Note that
98by default `conv` flips the spatial dimensions of its kernel (default
`flipped=false`), whereas `unfold` does not (default `flipped=true`).
Uses `NNlib.im2col!` as backend.
1211
See also [`fold`](@ref), the adjoint/transpose operator
1413and a potential inverse of `unfold`.
1514
1615# Example
@@ -23,7 +22,7 @@ julia> w = reshape([1 0 -1], 3, 1, 1); # 1D conv kernel of length 3
2322
2423julia> kws = (pad=1, stride=2, flipped=true); # use same args for conv and unfold
2524
julia> z = NNlib.unfold(x, size(w); kws...)
27264×3×1 Array{Int64, 3}:
2827[:, :, 1] =
2928 0 100 2
6160
6261The adjoint/transpose operator of `unfold`. It accumulates sliding windows from
6362the output of `unfold` into a container tensor of size `output_size`. An inverse
to `unfold` may be obtained (in some cases) by using `fold` and accounting for scaling issues
with a divisor (see example). Uses `NNlib.col2im!` as backend.
6665
6766See also [`unfold`](@ref).
6867
@@ -101,7 +100,7 @@ julia> divisor = NNlib.fold(NNlib.unfold(ones(size(x)...), (3,1,1)), size(x), (3
101100 2.0
102101 1.0
103102
julia> z ./ divisor
1051047×1×1 Array{Float64, 3}:
106105[:, :, 1] =
107106 100.0
@@ -133,30 +132,30 @@ function unfold(x::AbstractArray{T, N}, cdims::DenseConvDims) where {T, N}
133132end
134133
# Out-of-place `fold`: allocate a destination of size `output_size` with the
# same storage type as `y`, then accumulate into it via the in-place `fold!`.
function fold(y::AbstractArray{T, 3}, output_size::NTuple, cdims::DenseConvDims) where {T}
    out = similar(y, output_size)
    return fold!(out, y, cdims)
end
139138
# N < 5-dimension in-place versions
# Lift an N-dimensional (N < 5) input to the canonical 5-d layout by inserting
# singleton spatial dimensions, then dispatch to the 5-d in-place implementation.
function unfold!(y::AbstractArray{yT, 3}, x::AbstractArray{xT, N}, cdims::DenseConvDims) where {yT, xT, N}
    missing_dims = 5 - N
    x5 = insert_singleton_spatial_dimension(x, missing_dims)
    cdims5 = insert_singleton_spatial_dimension(cdims, missing_dims)
    unfold!(y, x5, cdims5)
    return y
end
149148
# Lifting wrapper for `fold!`: pad the destination `x` and `cdims` up to the
# 5-d layout with singleton spatial dimensions and forward to the 5-d method.
function fold!(x::AbstractArray{xT, N}, y::AbstractArray{yT, 3}, cdims::DenseConvDims) where {yT, xT, N}
    missing_dims = 5 - N
    fold!(
        insert_singleton_spatial_dimension(x, missing_dims),
        y,
        insert_singleton_spatial_dimension(cdims, missing_dims),
    )
    return x
end
158157
# 5-dimension in-place versions
160159function unfold! (y:: AbstractArray{yT, 3} , x:: AbstractArray{xT, 5} , cdims:: DenseConvDims ) where {yT, xT}
161160 @threads for batch_idx in 1 : size (x, 5 )
162161 y_slice = view (y, :, :, batch_idx)
@@ -173,6 +172,110 @@ function fold!(x::AbstractArray{xT, 5}, y::AbstractArray{yT, 3}, cdims::DenseCon
173172 return x
174173end
175174
# GPU kernel: one work-item per element of the reshaped `col` buffer. Each
# work-item projects its (window, kernel-offset) coordinates back onto the
# padded input and copies the matching pixel of `x` into `col` (or writes an
# explicit zero when the projected position falls in the padding).
@kernel function unfold_kernel!(
    col::AbstractArray{T}, x, col_size,
    input_size, output_size, kernel_size,
    flipkernel, stride, pad_lo, dilation, max_idx,
) where T
    idx = @index(Global)

    @inbounds if idx ≤ max_idx
        # Destination indices within `col`: window id, kernel offsets, channel, batch.
        i, kw, kh, kd, c, b = CartesianIndices(col_size)[idx].I
        # Spatial position of window `i` in the convolution output grid.
        w, h, d = CartesianIndices(output_size)[i].I

        # Project onto (padded) input coordinates.
        w, h, d = @. ((w, h, d) - 1) * stride - pad_lo + 1 + ((kw, kh, kd) - 1) * dilation

        # Match `conv`'s kernel orientation when not flipped.
        if !flipkernel
            kw, kh, kd = kernel_size .- (kw, kh, kd) .+ 1
        end

        if all(checkindex.(Bool, UnitRange.(1, input_size), (w, h, d)))
            val::T = x[w, h, d, c, b]
            col[i, kw, kh, kd, c, b] = val
        else
            # Out-of-bounds position ⇒ padding region ⇒ zero.
            col[i, kw, kh, kd, c, b] = T(0)
        end
    end
end
202+
# GPU kernel: adjoint of `unfold_kernel!`. Each work-item reads one element of
# the reshaped `col` buffer and accumulates it back into `x` at the projected
# input position; padding positions are skipped entirely.
@kernel function fold_kernel!(
    x::AbstractArray{T}, col, col_size,
    input_size, output_size, kernel_size,
    flipkernel, stride, pad_lo, dilation, max_idx,
) where T
    idx = @index(Global)

    @inbounds if idx ≤ max_idx
        # Source indices within `col`: window id, kernel offsets, channel, batch.
        i, kw, kh, kd, c, b = CartesianIndices(col_size)[idx].I
        # Spatial position of window `i` in the convolution output grid.
        w, h, d = CartesianIndices(output_size)[i].I

        # Project onto (padded) input coordinates.
        w, h, d = @. ((w, h, d) - 1) * stride - pad_lo + 1 + ((kw, kh, kd) - 1) * dilation

        if all(checkindex.(Bool, UnitRange.(1, input_size), (w, h, d)))
            if !flipkernel
                kw, kh, kd = kernel_size .- (kw, kh, kd) .+ 1
            end

            val::T = col[i, kw, kh, kd, c, b]
            # Sliding windows overlap, so several work-items may target the
            # same entry of `x` — the accumulation must be atomic.
            @atomic x[w, h, d, c, b] += val
        end
    end
end
228+
# GPU in-place `unfold`: extract sliding windows of `x` into `col` by launching
# one `unfold_kernel!` work-item per element of `col`.
#
# `col` is viewed as `(prod(output_size), kernel_w, kernel_h, kernel_d, C_in, batch)`
# so each work-item can address its window/offset/channel/batch directly.
# Only the 3-d (5-array) layout is accepted; lower dimensions must be lifted first.
function unfold!(
    col::AnyGPUArray{cT,3}, x::AnyGPUArray{xT,5}, cdims::DenseConvDims,
) where {cT, xT}
    # Fixed typo in the error message ("convoluitional" → "convolutional").
    spatial_dims(cdims) != 3 && throw(DimensionMismatch(
        "unfold!() only accepts 3d convolutional inputs"))

    C_in = channels_in(cdims)
    ker_size = kernel_size(cdims)
    # Only the low-side padding shifts the projected input coordinates.
    pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = padding(cdims)
    pad_lo = (pad_w_lo, pad_h_lo, pad_d_lo)

    out_size = output_size(cdims)
    col_reshaped = reshape(col, (prod(out_size), ker_size..., C_in, :))

    max_idx = prod(size(col))
    unfold_kernel!(get_backend(x))(
        col_reshaped, x, size(col_reshaped),
        input_size(cdims), out_size, ker_size,
        flipkernel(cdims), stride(cdims), pad_lo, dilation(cdims), max_idx;
        ndrange=max_idx)
    return col
end
251+
# GPU in-place `fold`: accumulate the sliding windows stored in `col` back into
# `x` (the adjoint of `unfold!`), launching one `fold_kernel!` work-item per
# element of `col`. `x` is zeroed first since the kernel accumulates atomically.
# Only the 3-d (5-array) layout is accepted; lower dimensions must be lifted first.
function fold!(
    x::AnyGPUArray{xT,5}, col::AnyGPUArray{cT,3}, cdims::DenseConvDims,
) where {xT, cT}
    # Fixed typo in the error message ("convoluitional" → "convolutional").
    spatial_dims(cdims) != 3 && throw(DimensionMismatch(
        "fold!() only accepts 3d convolutional inputs"))

    # going to accumulate into x
    fill!(x, xT(0))

    C_in = channels_in(cdims)
    ker_size = kernel_size(cdims)
    # Only the low-side padding shifts the projected input coordinates.
    pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = padding(cdims)
    pad_lo = (pad_w_lo, pad_h_lo, pad_d_lo)
    out_size = output_size(cdims)

    col_reshaped = reshape(col, (prod(out_size), ker_size..., C_in, :))

    max_idx = prod(size(col))
    fold_kernel!(get_backend(x))(
        x, col_reshaped, size(col_reshaped),
        input_size(cdims), out_size, ker_size,
        flipkernel(cdims), stride(cdims), pad_lo, dilation(cdims), max_idx;
        ndrange=max_idx)

    return x
end
278+
176279# reverse diff rules
177280function rrule (:: typeof (unfold), x, cdims:: DenseConvDims ; kw... )
178281 function unfold_pullback (Δ)
0 commit comments