Skip to content

Commit a4111c1

Browse files
authored
Use KernelAbstractions for fold/unfold (#596)
1 parent 52f22f9 commit a4111c1

File tree

9 files changed

+176
-177
lines changed

9 files changed

+176
-177
lines changed

ext/NNlibCUDAExt/NNlibCUDAExt.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ include("activations.jl")
99
include("batchedadjtrans.jl")
1010
include("batchedmul.jl")
1111
include("ctc.jl")
12-
include("fold.jl")
1312
include("scatter.jl")
1413
include("utils.jl")
1514

ext/NNlibCUDAExt/fold.jl

Lines changed: 0 additions & 111 deletions
This file was deleted.

ext/NNlibFFTWExt/stft.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ function NNlib.stft(x;
4141
ids = [
4242
row + hop_length * col
4343
for row in 1:n_fft, col in 0:(n_frames - 1)]
44-
x = x[ids, ntuple(_ -> Colon(), ndims(x) - 1)...]
44+
x = @inbounds x[ids, ntuple(_ -> Colon(), ndims(x) - 1)...]
4545
end
4646

4747
region = 1
@@ -113,7 +113,7 @@ function NNlib.istft(y;
113113
# In case of batched input, reshape it (n_fft, n_frames, batch) -> (:, batch).
114114
nd = ntuple(_ -> Colon(), ndims(x) - 2)
115115
ndims(x) == 3 && (x = reshape(x, (:, size(x, 3)));)
116-
x = x[ids, nd...]
116+
x = @inbounds x[ids, nd...]
117117

118118
# Trim padding.
119119
left = center ? (n_fft ÷ 2 + 1) : 1

src/audio/spectrogram.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ function spectrogram(waveform;
4141
window_normalized && (spec = spec .* inv(norm(window));)
4242

4343
if power > 0
44-
p = real(eltype(spec)(power))
44+
p = eltype(waveform)(power)
4545
spec = abs.(spec).^p
4646
end
4747
return spec

src/audio/stft.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ and ``m`` is the index of the sliding window.
149149
- `x`: Input, must be either a 1D time sequence (`(L,)` shape)
150150
or a 2D batch of time sequence (`(L, B)` shape).
151151
152-
# Positional Arguments:
152+
# Keyword Arguments:
153153
154154
- `n_fft::Int`: Size of Fourier transform.
155155
- `hop_length::Int`: Distance between neighboring sliding window frames.
@@ -187,7 +187,7 @@ Return the least squares estimation of the original signal
187187
- `y`: Input complex array in the `(n_fft, n_frames, B)` shape.
188188
Where `B` is the optional batch dimension.
189189
190-
# Positional Arguments:
190+
# Keyword Arguments:
191191
192192
- `n_fft::Int`: Size of Fourier transform.
193193
- `hop_length::Int`: Distance between neighboring sliding window frames.

src/fold.jl

Lines changed: 119 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
"""
32
unfold(x, kernel_size; stride = 1, pad = 0, dilation = 1, flipped = true)
43
@@ -7,10 +6,10 @@ window_size, batchsize)`. The window size is determined by the `prod(spatial dim
76
of kernel)*input_channels`. The number of sliding windows will match those of
87
convolution (`conv`) with the same kernel_size and arguments. Note that
98
by default `conv` flips the spatial dimensions of its kernel (default
10-
`flipped=false`), whereas `unfold` does not (default `flipped=true`).
11-
Uses `NNlib.im2col!` as backend.
9+
`flipped=false`), whereas `unfold` does not (default `flipped=true`).
10+
Uses `NNlib.im2col!` as backend.
1211
13-
See also [`fold`](@ref), the adjoint/transpose operator
12+
See also [`fold`](@ref), the adjoint/transpose operator
1413
and a potential inverse of `unfold`.
1514
1615
# Example
@@ -23,7 +22,7 @@ julia> w = reshape([1 0 -1], 3, 1, 1); # 1D conv kernel of length 3
2322
2423
julia> kws = (pad=1, stride=2, flipped=true); # use same args for conv and unfold
2524
26-
julia> z = NNlib.unfold(x, size(w); kws...)
25+
julia> z = NNlib.unfold(x, size(w); kws...)
2726
4×3×1 Array{Int64, 3}:
2827
[:, :, 1] =
2928
0 100 2
@@ -61,8 +60,8 @@ end
6160
6261
The adjoint/transpose operator of `unfold`. It accumulates sliding windows from
6362
the output of `unfold` into a container tensor of size `output_size`. An inverse
64-
to `unfold` may be obtained (in some cases) by using `fold` and accounting for scaling issues
65-
with a divisor (see example). Uses `NNlib.col2im!` as backend.
63+
to `unfold` may be obtained (in some cases) by using `fold` and accounting for scaling issues
64+
with a divisor (see example). Uses `NNlib.col2im!` as backend.
6665
6766
See also [`unfold`](@ref).
6867
@@ -101,7 +100,7 @@ julia> divisor = NNlib.fold(NNlib.unfold(ones(size(x)...), (3,1,1)), size(x), (3
101100
2.0
102101
1.0
103102
104-
julia> z ./ divisor
103+
julia> z ./ divisor
105104
7×1×1 Array{Float64, 3}:
106105
[:, :, 1] =
107106
100.0
@@ -133,30 +132,30 @@ function unfold(x::AbstractArray{T, N}, cdims::DenseConvDims) where {T, N}
133132
end
134133

135134
function fold(y::AbstractArray{T, 3}, output_size::NTuple, cdims::DenseConvDims) where {T}
136-
x = similar(y, output_size)
135+
x = similar(y, output_size)
137136
return fold!(x, y, cdims)
138137
end
139138

140-
# N < 5 -dimension in-place versions
139+
# N < 5 -dimension in-place versions
141140
function unfold!(y::AbstractArray{yT, 3}, x::AbstractArray{xT, N}, cdims::DenseConvDims) where {yT, xT, N}
142141
unfold!(
143-
y,
144-
insert_singleton_spatial_dimension(x, 5-N),
145-
insert_singleton_spatial_dimension(cdims, 5-N),
142+
y,
143+
insert_singleton_spatial_dimension(x, 5-N),
144+
insert_singleton_spatial_dimension(cdims, 5-N),
146145
)
147146
return y
148147
end
149148

150149
function fold!(x::AbstractArray{xT, N}, y::AbstractArray{yT, 3}, cdims::DenseConvDims) where {yT, xT, N}
151150
fold!(
152-
insert_singleton_spatial_dimension(x, 5-N),
151+
insert_singleton_spatial_dimension(x, 5-N),
153152
y,
154-
insert_singleton_spatial_dimension(cdims, 5-N),
153+
insert_singleton_spatial_dimension(cdims, 5-N),
155154
)
156155
return x
157156
end
158157

159-
# 5-dimension in-place versions
158+
# 5-dimension in-place versions
160159
function unfold!(y::AbstractArray{yT, 3}, x::AbstractArray{xT, 5}, cdims::DenseConvDims) where {yT, xT}
161160
@threads for batch_idx in 1:size(x, 5)
162161
y_slice = view(y, :, :, batch_idx)
@@ -173,6 +172,110 @@ function fold!(x::AbstractArray{xT, 5}, y::AbstractArray{yT, 3}, cdims::DenseCon
173172
return x
174173
end
175174

175+
@kernel function unfold_kernel!(
176+
col::AbstractArray{T}, x, col_size,
177+
input_size, output_size, kernel_size,
178+
flipkernel, stride, pad_lo, dilation, max_idx,
179+
) where T
180+
index = @index(Global)
181+
182+
@inbounds if index ≤ max_idx
183+
i, kw, kh, kd, c, b = CartesianIndices(col_size)[index].I # col indices
184+
w, h, d = CartesianIndices(output_size)[i].I # x indices
185+
186+
# project
187+
w, h, d = @. ((w, h, d) - 1) * stride - pad_lo + 1 + ((kw, kh, kd) - 1) * dilation
188+
189+
if !flipkernel
190+
kw, kh, kd = kernel_size .- (kw, kh, kd) .+ 1
191+
end
192+
193+
# check out of bounds
194+
if !all(checkindex.(Bool, UnitRange.(1, input_size), (w, h, d)))
195+
col[i, kw, kh, kd, c, b] = T(0)
196+
else
197+
xval::T = x[w, h, d, c, b]
198+
col[i, kw, kh, kd, c, b] = xval
199+
end
200+
end
201+
end
202+
203+
@kernel function fold_kernel!(
204+
x::AbstractArray{T}, col, col_size,
205+
input_size, output_size, kernel_size,
206+
flipkernel, stride, pad_lo, dilation, max_idx,
207+
) where T
208+
index = @index(Global)
209+
210+
@inbounds if index ≤ max_idx
211+
i, kw, kh, kd, c, b = CartesianIndices(col_size)[index].I # col indices
212+
w, h, d = CartesianIndices(output_size)[i].I # x indices
213+
214+
# project
215+
w, h, d = @. ((w, h, d) - 1) * stride - pad_lo + 1 + ((kw, kh, kd) - 1) * dilation
216+
217+
# check out of bounds
218+
if all(checkindex.(Bool, UnitRange.(1, input_size), (w, h, d)))
219+
if !flipkernel
220+
kw, kh, kd = kernel_size .- (kw, kh, kd) .+ 1
221+
end
222+
223+
cval::T = col[i, kw, kh, kd, c, b]
224+
@atomic x[w, h, d, c, b] += cval
225+
end
226+
end
227+
end
228+
229+
function unfold!(
230+
col::AnyGPUArray{cT,3}, x::AnyGPUArray{xT,5}, cdims::DenseConvDims,
231+
) where {cT, xT}
232+
spatial_dims(cdims) != 3 && throw(DimensionMismatch(
233+
"unfold!() only accepts 3d convolutional inputs"))
234+
235+
C_in = channels_in(cdims)
236+
ker_size = kernel_size(cdims)
237+
pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = padding(cdims)
238+
pad_lo = (pad_w_lo, pad_h_lo, pad_d_lo)
239+
240+
out_size = output_size(cdims)
241+
col_reshaped = reshape(col, (prod(out_size), ker_size..., C_in, :))
242+
243+
max_idx = prod(size(col))
244+
unfold_kernel!(get_backend(x))(
245+
col_reshaped, x, size(col_reshaped),
246+
input_size(cdims), out_size, ker_size,
247+
flipkernel(cdims), stride(cdims), pad_lo, dilation(cdims), max_idx;
248+
ndrange=max_idx)
249+
return col
250+
end
251+
252+
function fold!(
253+
x::AnyGPUArray{xT,5}, col::AnyGPUArray{cT,3}, cdims::DenseConvDims,
254+
) where {xT, cT}
255+
spatial_dims(cdims) != 3 && throw(DimensionMismatch(
256+
"fold!() only accepts 3d convolutional inputs"))
257+
258+
# going to accumulate into x
259+
fill!(x, xT(0))
260+
261+
C_in = channels_in(cdims)
262+
ker_size = kernel_size(cdims)
263+
pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = padding(cdims)
264+
pad_lo = (pad_w_lo, pad_h_lo, pad_d_lo)
265+
out_size = output_size(cdims)
266+
267+
col_reshaped = reshape(col, (prod(out_size), ker_size..., C_in, :))
268+
269+
max_idx = prod(size(col))
270+
fold_kernel!(get_backend(x))(
271+
x, col_reshaped, size(col_reshaped),
272+
input_size(cdims), out_size, ker_size,
273+
flipkernel(cdims), stride(cdims), pad_lo, dilation(cdims), max_idx;
274+
ndrange=max_idx)
275+
276+
return x
277+
end
278+
176279
# reverse diff rules
177280
function rrule(::typeof(unfold), x, cdims::DenseConvDims; kw...)
178281
function unfold_pullback(Δ)

0 commit comments

Comments
 (0)