Skip to content

Commit f84c02c

Browse files
authored
Move ctc_loss from Flux to NNlibCUDA (#55)
* move ctc loss from Flux * fixup * trivial * rm cpu
1 parent 7953c92 commit f84c02c

File tree

5 files changed

+291
-2
lines changed

5 files changed

+291
-2
lines changed

ext/NNlibCUDA/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "NNlibCUDA"
22
uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d"
3-
version = "0.2.3"
3+
version = "0.2.4"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -13,7 +13,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1313
[compat]
1414
Adapt = "3.3"
1515
CUDA = "3.11"
16-
NNlib = "0.8.7"
16+
NNlib = "0.8.9"
1717
julia = "1.6"
1818

1919
[extras]

ext/NNlibCUDA/src/NNlibCUDA.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ include("sampling.jl")
1111
include("activations.jl")
1212
include("batchedadjtrans.jl")
1313
include("batchedmul.jl")
14+
include("ctc.jl")
1415
include("scatter.jl")
1516
include("gather.jl")
1617
include("utils.jl")

ext/NNlibCUDA/src/ctc.jl

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
# CTC loss moved from Flux.jl to NNlib + NNlibCUDA
2+
3+
import NNlib: ctc_loss, ctc_alpha, ∇ctc_loss
4+
5+
## GPU implementation
6+
7+
# a port of the GPU kernels from Baidu's C++ warp-ctc package,
8+
# which itself is Copyright 2015-2016 Baidu USA LLC
9+
# and available under the Apache 2.0 license
10+
#
11+
# Apache 2.0 license: https://www.apache.org/licenses/LICENSE-2.0
12+
# GitHub: https://github.com/baidu-research/warp-ctc/
13+
# paper: https://arxiv.org/pdf/1512.02595.pdf
14+
15+
const MAX_THREADS = 256
16+
17+
function log_plus_f(p1, p2)
18+
isinf(p1) && return p2
19+
isinf(p2) && return p1
20+
if p1 < p2
21+
p1, p2 = p2, p1
22+
end
23+
return p1 + log(1+exp(p2 - p1))
24+
end
25+
26+
function count_repeats(A)
27+
repeats = 0
28+
for (i,elem) in enumerate(A)
29+
if i > 1 && A[i] == A[i-1]
30+
repeats += 1
31+
end
32+
end
33+
return repeats
34+
end
35+
36+
function compute_alpha_kernel(probs, labelSize, uttLength, repeats, labelsWithoutBlanks, labelsWithBlanks, alpha, blankLabel)
37+
38+
tid = threadIdx().x
39+
L = labelSize
40+
T = uttLength
41+
S = length(labelsWithBlanks)
42+
43+
if L + repeats > T
44+
return nothing
45+
end
46+
labels = labelsWithBlanks
47+
48+
# Corner-case checking
49+
start = (L + repeats <= T) ? 0 : 1
50+
last = S > 1 ? 2 : 1
51+
52+
# Fill in first column (time step)
53+
i = tid
54+
while i <= last - start
55+
alpha[start+i, 1] = probs[labels[start+i], 1]
56+
i += blockDim().x
57+
end
58+
sync_threads()
59+
60+
# Fill in coefficients for each time step
61+
for t=2:T
62+
# Corner-case checking
63+
if tid == 1 && !(1 < S - 2*(T-t) - 1)
64+
if start == 0
65+
alpha[1, t] = alpha[1, t-1] + probs[blankLabel, t]
66+
elseif start == 1
67+
alpha[1, t] = alpha[1, t-1]
68+
end
69+
end
70+
sync_threads()
71+
72+
# Fill in coefficients for each label class in the target output sequence;
73+
# each thread will process the calculations for one class
74+
idx = tid+1
75+
while idx <= S
76+
prevSum = log_plus_f(alpha[idx, t-1], alpha[idx-1, t-1])
77+
if labels[idx] != blankLabel && idx != 2 && labels[idx] != labels[idx-2]
78+
prevSum = log_plus_f(prevSum, alpha[idx-2, t-1])
79+
end
80+
if idx < S - 2*(T-t) - 1
81+
alpha[idx, t] = -Inf32
82+
else
83+
alpha[idx, t] = prevSum + probs[labels[idx], t]
84+
end
85+
idx += blockDim().x
86+
end
87+
sync_threads()
88+
end
89+
return nothing
90+
end
91+
92+
function compute_beta_and_grad_kernel(probs, labelSize, uttLength,
93+
repeatsInLabel, labelsWithBlanks,
94+
alphas, beta, output, accum,
95+
grad, blankLabel, loss)
96+
97+
tid = threadIdx().x
98+
L = labelSize
99+
T = uttLength
100+
S = 2*L + 1
101+
repeats = repeatsInLabel
102+
labels = labelsWithBlanks
103+
104+
if (L+repeats) > T
105+
return nothing
106+
end
107+
108+
# Corner-case checking
109+
start = S > 1 ? S-2 : 0
110+
last = L + repeats < T ? S : S-1
111+
sync_threads()
112+
i = tid
113+
114+
# Calculate coefficients for last column (time step)
115+
# then determine alpha and beta product
116+
while i <= last - start
117+
beta[i+start, T] = 0
118+
output[i+start, T] = beta[i+start, T] + alphas[i+start, T]
119+
i += blockDim().x
120+
end
121+
sync_threads()
122+
123+
# Fill in `accum` for last column (time step)
124+
if tid == 1
125+
for i=1:S
126+
labelIdx = labels[i]
127+
accum[labelIdx, T] = log_plus_f(accum[labelIdx, T], output[i, T])
128+
end
129+
end
130+
sync_threads()
131+
132+
# Fill in `grad` for last column (time step)
133+
idx = tid
134+
while idx <= size(grad, 1)
135+
s = -Inf32
136+
for i=1:S
137+
s = log_plus_f(s, output[i, T])
138+
end
139+
140+
# ∂L/∂a (where a is activation before logsoftmax)
141+
grad[idx, T] = exp(probs[idx, T]) - exp(accum[idx, T] - s)
142+
idx += blockDim().x
143+
end
144+
sync_threads()
145+
146+
# Fill in the rest of the coefficients
147+
t = T-1
148+
while t >= 1
149+
if t < T
150+
idx = tid
151+
while idx <= S
152+
nextSum = probs[labels[idx], t+1] + beta[idx, t+1]
153+
if idx < S
154+
nextSum = log_plus_f(nextSum,
155+
probs[labels[idx+1], t+1] + beta[idx+1, t+1])
156+
end
157+
if labels[idx] != blankLabel && idx != S-1 && labels[idx] != labels[idx+2]
158+
nextSum = log_plus_f(nextSum,
159+
probs[labels[idx+2], t+1] + beta[idx + 2, t+1])
160+
end
161+
if idx > 2*t
162+
beta[idx, t] = -Inf32
163+
else
164+
beta[idx, t] = nextSum
165+
end
166+
idx += blockDim().x
167+
end
168+
sync_threads()
169+
idx = tid
170+
while idx <= S
171+
output[idx, t] = alphas[idx, t] + beta[idx, t]
172+
idx += blockDim().x
173+
end
174+
sync_threads()
175+
end
176+
sync_threads()
177+
178+
# Calculate accumulated alpha-beta products for each label class for
179+
# each time step; used in calculating gradients
180+
if tid == 1
181+
for i=1:S
182+
labelIdx = labels[i]
183+
accum[labelIdx, t] = log_plus_f(accum[labelIdx, t], output[i, t])
184+
end
185+
end
186+
sync_threads()
187+
idx = tid
188+
189+
# Calculate gradients
190+
while idx <= size(grad, 1)
191+
192+
# ∂L/∂a (where a is activation before logsoftmax)
193+
grad[idx, t] = exp(probs[idx, t]) - exp(accum[idx, t] + loss)
194+
idx += blockDim().x
195+
end
196+
sync_threads()
197+
t -= 1
198+
sync_threads()
199+
end
200+
return nothing
201+
end
202+
203+
function ctc_alpha(ŷ::CuArray, y)
204+
= logsoftmax(ŷ)
205+
blank = size(ŷ, 1)
206+
ycu = cu(y)
207+
z′ = CUDA.fill(blank, 2 * length(y) + 1)
208+
z′[eachindex(y) .* 2] .= ycu
209+
T = size(ŷ, 2)
210+
U′ = 2*length(y) + 1
211+
alphas = CUDA.fill(log(zero(eltype(ŷ))), U′,T)
212+
nRepeats = count_repeats(CUDA.adapt(Array, y))
213+
nThreads = min(U′, MAX_THREADS)
214+
@cuda blocks=1 threads=nThreads compute_alpha_kernel(ŷ, length(y), T, nRepeats, ycu, z′, alphas, blank)
215+
return (loss=-1 * logsumexp(alphas[end-1:end]), alpha=alphas, z′=z′, yhat=ŷ, nRepeats=nRepeats)
216+
end
217+
218+
ctc_loss(ŷ::CuArray, y) = ctc_alpha(ŷ::CuArray, y).loss
219+
220+
function ∇ctc_loss(ŷ::CuArray, y, out)
221+
loss, alphas, z′, ŷ, nRepeats = out
222+
U′, T = size(alphas)
223+
blank = size(ŷ, 1)
224+
typed_zero = zero(eltype(ŷ))
225+
betas = CUDA.fill(log(typed_zero), U′, T)
226+
output = CUDA.fill(log(typed_zero), U′, T)
227+
nThreads = min(U′, MAX_THREADS)
228+
grads = CUDA.fill(log(typed_zero), size(ŷ))
229+
accum = CUDA.fill(log(typed_zero), size(ŷ))
230+
@cuda blocks=1 threads=nThreads compute_beta_and_grad_kernel(ŷ, length(y), T, nRepeats, CuArray(z′), alphas, betas, output, accum, grads, blank, loss)
231+
return grads
232+
end

ext/NNlibCUDA/test/ctc.jl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
using Test
2+
using NNlib: ctc_loss
3+
using Zygote: gradient
4+
using LinearAlgebra
5+
using CUDA, NNlibCUDA
6+
7+
# Custom function to check numerical gradient of ctc loss,
8+
# based on `ngradient` in `Tracker.jl`
9+
function ctc_ngradient(x, y)
10+
f = ctc_loss
11+
grads = zero(x)
12+
for i in 1:length(x)
13+
δ = sqrt(eps())
14+
tmp = x[i]
15+
x[i] = tmp - δ/2
16+
y1 = f(x, y)
17+
x[i] = tmp + δ/2
18+
y2 = f(x, y)
19+
x[i] = tmp
20+
grads[i] = (y2-y1)/δ
21+
end
22+
return grads
23+
end
24+
25+
@testset "ctc-gpu" begin
26+
x = rand(10, 50)
27+
y = rand(1:9, 30)
28+
x_cu = CuArray(x)
29+
g1 = gradient(ctc_loss, x_cu, y)[1]
30+
g1 = g1 |> collect
31+
g2 = ctc_ngradient(x, y)
32+
@test g1 g2 rtol=1e-5 atol=1e-5
33+
34+
# test that GPU loss matches CPU implementation
35+
l1 = ctc_loss(x_cu, y)
36+
l2 = ctc_loss(x, y)
37+
@test l1 l2
38+
39+
# tests using hand-calculated values
40+
x_cu = [1. 2. 3.; 2. 1. 1.; 3. 3. 2.] |> CuArray
41+
y = [1, 2]
42+
@test ctc_loss(x_cu, y) 3.6990738275138035
43+
44+
g = [-0.317671 -0.427729 0.665241; 0.244728 -0.0196172 -0.829811; 0.0729422 0.447346 0.16457]
45+
ghat = gradient(ctc_loss, x_cu, y)[1] |> collect
46+
@test g ghat rtol=1e-5 atol=1e-5
47+
48+
x_cu = [-3. 12. 8. 15.; 4. 20. -2. 20.; 8. -33. 6. 5.] |> CuArray
49+
y = [1, 2] |> CuArray
50+
@test ctc_loss(x_cu, y) 8.02519869363453
51+
52+
g = [-2.29294774655333e-06 -0.999662657278862 1.75500863563993e-06 0.00669284889063; 0.017985914969696 0.999662657278861 -1.9907078755387e-06 -0.006693150917307; -0.01798362202195 -2.52019580677916e-20 2.35699239251042e-07 3.02026677058789e-07]
53+
ghat = gradient(ctc_loss, x_cu, y)[1] |> collect
54+
@test g ghat rtol=1e-5 atol=1e-5
55+
end

ext/NNlibCUDA/test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ include("batchedadjtrans.jl")
1414
include("batchedmul.jl")
1515
include("upsample.jl")
1616
include("conv.jl")
17+
include("ctc.jl")
1718
include("pooling.jl")
1819
include("softmax.jl")
1920
include("batchnorm.jl")

0 commit comments

Comments
 (0)