
Commit 4002114

Merge branch 'master' into ap/patch
2 parents 2d11083 + f84c02c commit 4002114

File tree

11 files changed: +417 −116 lines

ext/NNlibCUDA/Project.toml

Lines changed: 5 additions & 3 deletions
@@ -1,17 +1,19 @@
 name = "NNlibCUDA"
 uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d"
-version = "0.2.3"
+version = "0.2.5"
 
 [deps]
+Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [compat]
-CUDA = "3.3.1"
-NNlib = "0.8.3"
+Adapt = "3.3"
+CUDA = "3.11"
+NNlib = "0.8.9"
 julia = "1.6"
 
 [extras]

ext/NNlibCUDA/src/NNlibCUDA.jl

Lines changed: 2 additions & 0 deletions
@@ -9,7 +9,9 @@ const IntOrIntTuple = Union{Integer, NTuple{N,<:Integer} where N}
 include("upsample.jl")
 include("sampling.jl")
 include("activations.jl")
+include("batchedadjtrans.jl")
 include("batchedmul.jl")
+include("ctc.jl")
 include("scatter.jl")
 include("gather.jl")
 include("utils.jl")

ext/NNlibCUDA/src/batchedadjtrans.jl

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+using NNlib: BatchedAdjoint, BatchedTranspose, BatchedAdjOrTrans
+using Adapt
+using Adapt: WrappedArray
+
+const CuBatchedAdjoint{T} = BatchedAdjoint{T, <: CuArray{T}}
+const CuBatchedTranspose{T} = BatchedTranspose{T, <: CuArray{T}}
+const CuBatchedAdjOrTrans{T} = Union{CuBatchedAdjoint{T}, CuBatchedTranspose{T}}
+const WrappedCuBatchedAdjOrTrans{T, N} = WrappedArray{T, N, CuBatchedAdjOrTrans{T}, CuBatchedAdjOrTrans{T}}
+
+
+Base.print_array(io::IO, b::Union{CuBatchedAdjOrTrans, WrappedCuBatchedAdjOrTrans}) = Base.print_array(io, adapt(Array, b))
+Base._show_nonempty(io::IO, b::Union{CuBatchedAdjOrTrans, WrappedCuBatchedAdjOrTrans}, prefix::String) = Base._show_nonempty(io, adapt(Array, b), prefix)
+Base.show_vector(io::IO, b::Union{CuBatchedAdjOrTrans, WrappedCuBatchedAdjOrTrans}, opn, cls) = Base.show_vector(io, adapt(Array, b), opn, cls)
+
+Base.convert(::Type{T}, b::Union{CuBatchedAdjOrTrans, WrappedCuBatchedAdjOrTrans}) where {T<:Array} = Base.convert(T, adapt(Array, b))
+Base.Array{T, N}(b::Union{CuBatchedAdjOrTrans, WrappedCuBatchedAdjOrTrans}) where {T, N} = Array{T, N}(adapt(Array, b))
+Base.collect(b::Union{CuBatchedAdjOrTrans, WrappedCuBatchedAdjOrTrans}) = collect(adapt(Array, b))
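
These methods route display, conversion, and collection of batched adjoint/transpose wrappers around CuArrays through adapt(Array, b), so the data is copied to the host first. A rough sketch of the behaviour they enable (assuming a working CUDA device; batched_transpose comes from NNlib):

using CUDA, NNlib, NNlibCUDA

x  = CUDA.rand(Float32, 3, 4, 2)      # batch of two 3×4 matrices on the GPU
xt = NNlib.batched_transpose(x)       # lazy BatchedTranspose wrapper, still GPU-backed

collect(xt)                           # copies to the host, returns an Array{Float32,3}
Array{Float32,3}(xt)                  # same result via the explicit constructor
show(stdout, MIME"text/plain"(), xt)  # printing also round-trips through adapt(Array, xt)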

ext/NNlibCUDA/src/ctc.jl

Lines changed: 232 additions & 0 deletions
@@ -0,0 +1,232 @@
+# CTC loss moved from Flux.jl to NNlib + NNlibCUDA
+
+import NNlib: ctc_loss, ctc_alpha, ∇ctc_loss
+
+## GPU implementation
+
+# a port of the GPU kernels from Baidu's C++ warp-ctc package,
+# which itself is Copyright 2015-2016 Baidu USA LLC
+# and available under the Apache 2.0 license
+#
+# Apache 2.0 license: https://www.apache.org/licenses/LICENSE-2.0
+# GitHub: https://github.com/baidu-research/warp-ctc/
+# paper: https://arxiv.org/pdf/1512.02595.pdf
+
+const MAX_THREADS = 256
+
+function log_plus_f(p1, p2)
+    isinf(p1) && return p2
+    isinf(p2) && return p1
+    if p1 < p2
+        p1, p2 = p2, p1
+    end
+    return p1 + log(1+exp(p2 - p1))
+end
+
+function count_repeats(A)
+    repeats = 0
+    for (i,elem) in enumerate(A)
+        if i > 1 && A[i] == A[i-1]
+            repeats += 1
+        end
+    end
+    return repeats
+end
+
+function compute_alpha_kernel(probs, labelSize, uttLength, repeats, labelsWithoutBlanks, labelsWithBlanks, alpha, blankLabel)
+
+    tid = threadIdx().x
+    L = labelSize
+    T = uttLength
+    S = length(labelsWithBlanks)
+
+    if L + repeats > T
+        return nothing
+    end
+    labels = labelsWithBlanks
+
+    # Corner-case checking
+    start = (L + repeats <= T) ? 0 : 1
+    last = S > 1 ? 2 : 1
+
+    # Fill in first column (time step)
+    i = tid
+    while i <= last - start
+        alpha[start+i, 1] = probs[labels[start+i], 1]
+        i += blockDim().x
+    end
+    sync_threads()
+
+    # Fill in coefficients for each time step
+    for t=2:T
+        # Corner-case checking
+        if tid == 1 && !(1 < S - 2*(T-t) - 1)
+            if start == 0
+                alpha[1, t] = alpha[1, t-1] + probs[blankLabel, t]
+            elseif start == 1
+                alpha[1, t] = alpha[1, t-1]
+            end
+        end
+        sync_threads()
+
+        # Fill in coefficients for each label class in the target output sequence;
+        # each thread will process the calculations for one class
+        idx = tid+1
+        while idx <= S
+            prevSum = log_plus_f(alpha[idx, t-1], alpha[idx-1, t-1])
+            if labels[idx] != blankLabel && idx != 2 && labels[idx] != labels[idx-2]
+                prevSum = log_plus_f(prevSum, alpha[idx-2, t-1])
+            end
+            if idx < S - 2*(T-t) - 1
+                alpha[idx, t] = -Inf32
+            else
+                alpha[idx, t] = prevSum + probs[labels[idx], t]
+            end
+            idx += blockDim().x
+        end
+        sync_threads()
+    end
+    return nothing
+end
+
+function compute_beta_and_grad_kernel(probs, labelSize, uttLength,
+                                      repeatsInLabel, labelsWithBlanks,
+                                      alphas, beta, output, accum,
+                                      grad, blankLabel, loss)
+
+    tid = threadIdx().x
+    L = labelSize
+    T = uttLength
+    S = 2*L + 1
+    repeats = repeatsInLabel
+    labels = labelsWithBlanks
+
+    if (L+repeats) > T
+        return nothing
+    end
+
+    # Corner-case checking
+    start = S > 1 ? S-2 : 0
+    last = L + repeats < T ? S : S-1
+    sync_threads()
+    i = tid
+
+    # Calculate coefficients for last column (time step)
+    # then determine alpha and beta product
+    while i <= last - start
+        beta[i+start, T] = 0
+        output[i+start, T] = beta[i+start, T] + alphas[i+start, T]
+        i += blockDim().x
+    end
+    sync_threads()
+
+    # Fill in `accum` for last column (time step)
+    if tid == 1
+        for i=1:S
+            labelIdx = labels[i]
+            accum[labelIdx, T] = log_plus_f(accum[labelIdx, T], output[i, T])
+        end
+    end
+    sync_threads()
+
+    # Fill in `grad` for last column (time step)
+    idx = tid
+    while idx <= size(grad, 1)
+        s = -Inf32
+        for i=1:S
+            s = log_plus_f(s, output[i, T])
+        end
+
+        # ∂L/∂a (where a is activation before logsoftmax)
+        grad[idx, T] = exp(probs[idx, T]) - exp(accum[idx, T] - s)
+        idx += blockDim().x
+    end
+    sync_threads()
+
+    # Fill in the rest of the coefficients
+    t = T-1
+    while t >= 1
+        if t < T
+            idx = tid
+            while idx <= S
+                nextSum = probs[labels[idx], t+1] + beta[idx, t+1]
+                if idx < S
+                    nextSum = log_plus_f(nextSum,
+                        probs[labels[idx+1], t+1] + beta[idx+1, t+1])
+                end
+                if labels[idx] != blankLabel && idx != S-1 && labels[idx] != labels[idx+2]
+                    nextSum = log_plus_f(nextSum,
+                        probs[labels[idx+2], t+1] + beta[idx + 2, t+1])
+                end
+                if idx > 2*t
+                    beta[idx, t] = -Inf32
+                else
+                    beta[idx, t] = nextSum
+                end
+                idx += blockDim().x
+            end
+            sync_threads()
+            idx = tid
+            while idx <= S
+                output[idx, t] = alphas[idx, t] + beta[idx, t]
+                idx += blockDim().x
+            end
+            sync_threads()
+        end
+        sync_threads()
+
+        # Calculate accumulated alpha-beta products for each label class for
+        # each time step; used in calculating gradients
+        if tid == 1
+            for i=1:S
+                labelIdx = labels[i]
+                accum[labelIdx, t] = log_plus_f(accum[labelIdx, t], output[i, t])
+            end
+        end
+        sync_threads()
+        idx = tid
+
+        # Calculate gradients
+        while idx <= size(grad, 1)
+
+            # ∂L/∂a (where a is activation before logsoftmax)
+            grad[idx, t] = exp(probs[idx, t]) - exp(accum[idx, t] + loss)
+            idx += blockDim().x
+        end
+        sync_threads()
+        t -= 1
+        sync_threads()
+    end
+    return nothing
+end
+
+function ctc_alpha(ŷ::CuArray, y)
+    ŷ = logsoftmax(ŷ)
+    blank = size(ŷ, 1)
+    ycu = cu(y)
+    z′ = CUDA.fill(blank, 2 * length(y) + 1)
+    z′[eachindex(y) .* 2] .= ycu
+    T = size(ŷ, 2)
+    U′ = 2*length(y) + 1
+    alphas = CUDA.fill(log(zero(eltype(ŷ))), U′,T)
+    nRepeats = count_repeats(CUDA.adapt(Array, y))
+    nThreads = min(U′, MAX_THREADS)
+    @cuda blocks=1 threads=nThreads compute_alpha_kernel(ŷ, length(y), T, nRepeats, ycu, z′, alphas, blank)
+    return (loss=-1 * logsumexp(alphas[end-1:end]), alpha=alphas, z′=z′, yhat=ŷ, nRepeats=nRepeats)
+end
+
+ctc_loss(ŷ::CuArray, y) = ctc_alpha(ŷ::CuArray, y).loss
+
+function ∇ctc_loss(ŷ::CuArray, y, out)
+    loss, alphas, z′, ŷ, nRepeats = out
+    U′, T = size(alphas)
+    blank = size(ŷ, 1)
+    typed_zero = zero(eltype(ŷ))
+    betas = CUDA.fill(log(typed_zero), U′, T)
+    output = CUDA.fill(log(typed_zero), U′, T)
+    nThreads = min(U′, MAX_THREADS)
+    grads = CUDA.fill(log(typed_zero), size(ŷ))
+    accum = CUDA.fill(log(typed_zero), size(ŷ))
+    @cuda blocks=1 threads=nThreads compute_beta_and_grad_kernel(ŷ, length(y), T, nRepeats, CuArray(z′), alphas, betas, output, accum, grads, blank, loss)
+    return grads
+end
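
With this port, ctc_loss called on a CuArray of per-class scores dispatches to the kernels above. A minimal usage sketch (hypothetical sizes; by the convention in ctc_alpha the blank is the last class, i.e. size(ŷ, 1), and y contains label indices below that):

using CUDA, NNlib, NNlibCUDA

T, nclasses = 10, 5                    # 10 time steps; labels 1:4 plus the blank (class 5)
ŷ = CUDA.randn(Float32, nclasses, T)   # unnormalised scores, one column per time step
y = [1, 2, 2, 3]                       # target label sequence (no blanks)

loss = ctc_loss(ŷ, y)                  # forward pass via compute_alpha_kernel

out   = NNlib.ctc_alpha(ŷ, y)          # named tuple: loss, alpha table, z′, ŷ, nRepeats
grads = NNlib.∇ctc_loss(ŷ, y, out)     # ∂loss/∂ŷ, same size as ŷ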

ext/NNlibCUDA/src/cudnn/activations.jl

Lines changed: 2 additions & 2 deletions
@@ -19,13 +19,13 @@ for (f, op) in [
     @eval begin
         # in-place
         function Base.materialize!(dst::DenseCuArray{<:CUDNNFloat},
-                                   bc::Broadcast.Broadcasted{<:Any,<:Any,typeof($f),<:Tuple{DenseCuArray}})
+                                   bc::Broadcast.Broadcasted{<:Any,<:Any,typeof($f),<:Tuple{DenseCuArray{<:CUDNNFloat}}})
            $op(bc.args[1], dst)
            return dst
         end
 
         # out of place
-        function Base.materialize(bc::Broadcast.Broadcasted{<:Any,<:Any,typeof($f),<:Tuple{DenseCuArray}})
+        function Base.materialize(bc::Broadcast.Broadcasted{<:Any,<:Any,typeof($f),<:Tuple{DenseCuArray{<:CUDNNFloat}}})
            ElType = Broadcast.combine_eltypes(bc.f, bc.args)
            dst = similar(bc, ElType)
            $op(bc.args[1], dst)
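
The tightened Tuple{DenseCuArray{<:CUDNNFloat}} signature means the cuDNN fast path is only taken when the element type is one cuDNN supports; other element types fall back to the generic GPU broadcast. A rough illustration, assuming tanh is among the wrapped activations:

using CUDA, NNlibCUDA

x = CUDA.rand(Float32, 1024)   # Float32 <: CUDNNFloat, so tanh.(x) can hit the cuDNN kernel
y = tanh.(x)

z = tanh.(CuArray(1:1024))     # Int element type: handled by the generic broadcast instead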

ext/NNlibCUDA/src/cudnn/conv.jl

Lines changed: 2 additions & 2 deletions
@@ -97,7 +97,7 @@ function ∇conv_data!(dx::DenseCuArray{T}, dy::DenseCuArray{T}, w::DenseCuArray
     alpha, beta = scalingParameter(T,alpha), scalingParameter(T,beta);
     convDesc, dx, depad = cudnnConvolutionDescriptorAndPaddedInput(cdims, dx)
     xDesc, yDesc, wDesc = cudnnTensorDescriptor(dx), cudnnTensorDescriptor(dy), cudnnFilterDescriptor(w)
-    p = cudnnConvolutionBwdDataAlgoPerf(wDesc, w, yDesc, dy, convDesc, xDesc, dx)
+    p = cudnnConvolutionBwdDataAlgoPerf(wDesc, w, yDesc, dy, convDesc, xDesc, dx, beta!=0)
     with_workspace(p.memory) do workspace
         cudnnConvolutionBackwardData(handle(), alpha, wDesc, w, yDesc, dy, convDesc, p.algo, workspace, sizeof(workspace), beta, xDesc, dx)
     end
@@ -115,7 +115,7 @@ function ∇conv_filter!(dw::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArr
     alpha, beta = scalingParameter(T,alpha), scalingParameter(T,beta);
     convDesc, x, _ = cudnnConvolutionDescriptorAndPaddedInput(cdims, x)
     xDesc, yDesc, wDesc = cudnnTensorDescriptor(x), cudnnTensorDescriptor(dy), cudnnFilterDescriptor(dw)
-    p = cudnnConvolutionBwdFilterAlgoPerf(xDesc, x, yDesc, dy, convDesc, wDesc, dw);
+    p = cudnnConvolutionBwdFilterAlgoPerf(xDesc, x, yDesc, dy, convDesc, wDesc, dw, beta!=0);
     with_workspace(p.memory) do workspace
         cudnnConvolutionBackwardFilter(handle(), alpha, xDesc, x, yDesc, dy, convDesc, p.algo, workspace, sizeof(workspace), beta, wDesc, dw);
     end
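
Passing beta != 0 into the algo-perf query lets cuDNN exclude algorithms that cannot accumulate into an existing output buffer. A rough sketch of the two call modes (hypothetical shapes; alpha/beta are the keyword arguments visible in the surrounding function signatures):

using CUDA, NNlib, NNlibCUDA

x = CUDA.rand(Float32, 8, 8, 3, 2)        # W × H × C × N input
w = CUDA.rand(Float32, 3, 3, 3, 4)        # 3×3 kernel, 3 input → 4 output channels
cdims = DenseConvDims(x, w)
dy = CUDA.ones(Float32, size(conv(x, w, cdims)))

dx = CUDA.zeros(Float32, size(x))
∇conv_data!(dx, dy, w, cdims)             # beta == 0: dx is simply overwritten
∇conv_data!(dx, dy, w, cdims; beta = 1)   # beta != 0: result is accumulated into the existing dx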
