
Commit dd28321

bors[bot] and maetshju authored
Merge #1287
1287: Add CTC loss to new Losses module r=CarloLucibello a=maetshju

This is a redux of adding the connectionist temporal classification loss from #342, now that the Losses module has been merged in #1264. Discussion in #342 suggested that a new PR would be easier than rebasing. Since the last commit in #342, functions and data structures from `CUDAnative.jl` and `CuArrays.jl` have been updated to work with `CUDA.jl`. This is in addition to incorporating the loss function into the Losses module.

### PR Checklist

- [X] Tests are added
- [X] Entry in NEWS.md
- [X] Documentation, if applicable
- [ ] Final review from `@dhairyagandhi96` (for API changes).

Co-authored-by: Matt Kelley <[email protected]>
Co-authored-by: Matthew C. Kelley <[email protected]>
2 parents 02ea511 + bc94a16 commit dd28321
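
As a quick orientation before the per-file diffs, here is a minimal usage sketch of the loss this PR adds. It assumes a Flux checkout that contains this commit; the array sizes and label values are made up for illustration.

using Flux
using Flux.Losses: ctc_loss

# Hypothetical sizes: 10 output classes (the last class, 10, is the blank)
# over 25 time steps. ctc_loss applies logsoftmax internally, so these are
# raw network activations.
ŷ = rand(Float32, 10, 25)   # classes × time
y = [3, 7, 2, 2, 5]         # label sequence without blanks

loss = ctc_loss(ŷ, y)       # scalar negative log-likelihood of y given ŷ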

File tree

7 files changed: +482 −2 lines changed

NEWS.md

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@
 * Dense and Conv layers no longer perform [implicit type conversion](https://github.com/FluxML/Flux.jl/pull/1394).
 * Excise datasets in favour of other providers in the julia ecosystem.
 * Added option to set `bias` to [false](https://github.com/FluxML/Flux.jl/pull/1379) to eliminate `bias` from being trained.
+* Add [CTC loss function](https://github.com/FluxML/Flux.jl/pull/1287) to Losses module
 * Removed kwarg only constructors for [`convolutional layers`](https://github.com/FluxML/Flux.jl/pull/1379).
 * Add [sparse initialization](https://github.com/FluxML/Flux.jl/pull/1454) as described in [Deep learning via Hessian-free optimization](https://dl.acm.org/doi/abs/10.5555/3104322.3104416).
 * Moved GPU CI to use buildkite instead of GitLab

src/losses/Losses.jl

Lines changed: 5 additions & 2 deletions
@@ -17,9 +17,12 @@ export mse, mae, msle,
         tversky_loss,
         dice_coeff_loss,
         poisson_loss,
-        hinge_loss, squared_hinge_loss
+        hinge_loss, squared_hinge_loss,
+        ctc_loss

 include("utils.jl")
 include("functions.jl")
+include("ctc.jl")
+if CUDA.functional() include("ctc-gpu.jl") end

-end #module
+end #module
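
A note on the conditional include above: the GPU implementation is only loaded when a CUDA device is actually usable, so CPU-only installs never compile the kernels. A small illustrative sketch of the same check (assuming CUDA.jl is installed; not part of the diff):

using CUDA

# Mirrors the guard in Losses.jl: the CuArray method of ctc_loss is only
# defined when a CUDA device is usable.
if CUDA.functional()
    @info "GPU CTC kernels are available"
else
    @info "falling back to the CPU implementation in ctc.jl"
end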

src/losses/ctc-gpu.jl

Lines changed: 232 additions & 0 deletions
@@ -0,0 +1,232 @@
# GPU implementation

# a port of the GPU kernels from Baidu's C++ warp-ctc package,
# which itself is Copyright 2015-2016 Baidu USA LLC
# and available under the Apache 2.0 license
#
# Apache 2.0 license: https://www.apache.org/licenses/LICENSE-2.0
# GitHub: https://github.com/baidu-research/warp-ctc/
# paper: https://arxiv.org/pdf/1512.02595.pdf

using Flux
using Statistics
using CUDA
using NNlib

const MAX_THREADS = 256

function log_plus_f(p1, p2)
  isinf(p1) && return p2
  isinf(p2) && return p1
  if p1 < p2
    p1, p2 = p2, p1
  end
  return p1 + CUDA.log(1+CUDA.exp(p2 - p1))
end

function count_repeats(A)
  repeats = 0
  for (i,elem) in enumerate(A)
    if i > 1 && A[i] == A[i-1]
      repeats += 1
    end
  end
  return repeats
end

function compute_alpha_kernel(probs, labelSize, uttLength, repeats, labelsWithoutBlanks, labelsWithBlanks, alpha, blankLabel)

  tid = threadIdx().x
  L = labelSize
  T = uttLength
  S = length(labelsWithBlanks)

  if L + repeats > T
    return nothing
  end
  labels = labelsWithBlanks

  # Corner-case checking
  start = (L + repeats <= T) ? 0 : 1
  last = S > 1 ? 2 : 1

  # Fill in first column (time step)
  i = tid
  while i <= last - start
    alpha[start+i, 1] = probs[labels[start+i], 1]
    i += blockDim().x
  end
  sync_threads()

  # Fill in coefficients for each time step
  for t=2:T
    # Corner-case checking
    if tid == 1 && !(1 < S - 2*(T-t) - 1)
      if start == 0
        alpha[1, t] = alpha[1, t-1] + probs[blankLabel, t]
      elseif start == 1
        alpha[1, t] = alpha[1, t-1]
      end
    end
    sync_threads()

    # Fill in coefficients for each label class in the target output sequence;
    # each thread will process the calculations for one class
    idx = tid+1
    while idx <= S
      prevSum = log_plus_f(alpha[idx, t-1], alpha[idx-1, t-1])
      if labels[idx] != blankLabel && idx != 2 && labels[idx] != labels[idx-2]
        prevSum = log_plus_f(prevSum, alpha[idx-2, t-1])
      end
      if idx < S - 2*(T-t) - 1
        alpha[idx, t] = -Inf32
      else
        alpha[idx, t] = prevSum + probs[labels[idx], t]
      end
      idx += blockDim().x
    end
    sync_threads()
  end
  return nothing
end

function compute_beta_and_grad_kernel(probs, labelSize, uttLength,
                                      repeatsInLabel, labelsWithBlanks,
                                      alphas, beta, output, accum,
                                      grad, blankLabel, loss)

  tid = threadIdx().x
  L = labelSize
  T = uttLength
  S = 2*L + 1
  repeats = repeatsInLabel
  labels = labelsWithBlanks

  if (L+repeats) > T
    return nothing
  end

  # Corner-case checking
  start = S > 1 ? S-2 : 0
  last = L + repeats < T ? S : S-1
  sync_threads()
  i = tid

  # Calculate coefficients for last column (time step)
  # then determine alpha and beta product
  while i <= last - start
    beta[i+start, T] = 0
    output[i+start, T] = beta[i+start, T] + alphas[i+start, T]
    i += blockDim().x
  end
  sync_threads()

  # Fill in `accum` for last column (time step)
  if tid == 1
    for i=1:S
      labelIdx = labels[i]
      accum[labelIdx, T] = log_plus_f(accum[labelIdx, T], output[i, T])
    end
  end
  sync_threads()

  # Fill in `grad` for last column (time step)
  idx = tid
  while idx <= size(grad, 1)
    s = -Inf32
    for i=1:S
      s = log_plus_f(s, output[i, T])
    end

    # ∂L/∂a (where a is activation before logsoftmax)
    grad[idx, T] = CUDA.exp(probs[idx, T]) - CUDA.exp(accum[idx, T] - s)
    idx += blockDim().x
  end
  sync_threads()

  # Fill in the rest of the coefficients
  t = T-1
  while t >= 1
    if t < T
      idx = tid
      while idx <= S
        nextSum = probs[labels[idx], t+1] + beta[idx, t+1]
        if idx < S
          nextSum = log_plus_f(nextSum,
            probs[labels[idx+1], t+1] + beta[idx+1, t+1])
        end
        if labels[idx] != blankLabel && idx != S-1 && labels[idx] != labels[idx+2]
          nextSum = log_plus_f(nextSum,
            probs[labels[idx+2], t+1] + beta[idx + 2, t+1])
        end
        if idx > 2*t
          beta[idx, t] = -Inf32
        else
          beta[idx, t] = nextSum
        end
        idx += blockDim().x
      end
      sync_threads()
      idx = tid
      while idx <= S
        output[idx, t] = alphas[idx, t] + beta[idx, t]
        idx += blockDim().x
      end
      sync_threads()
    end
    sync_threads()

    # Calculate accumulated alpha-beta products for each label class for
    # each time step; used in calculating gradients
    if tid == 1
      for i=1:S
        labelIdx = labels[i]
        accum[labelIdx, t] = log_plus_f(accum[labelIdx, t], output[i, t])
      end
    end
    sync_threads()
    idx = tid

    # Calculate gradients
    while idx <= size(grad, 1)

      # ∂L/∂a (where a is activation before logsoftmax)
      grad[idx, t] = CUDA.exp(probs[idx, t]) - CUDA.exp(accum[idx, t] + loss)
      idx += blockDim().x
    end
    sync_threads()
    t -= 1
    sync_threads()
  end
  return nothing
end

function ctc_alpha(ŷ::CuArray, y)
  ŷ = logsoftmax(ŷ)
  blank = size(ŷ, 1)
  z′ = fill(blank, 2 * length(y) + 1)
  z′[eachindex(y) .* 2] = y
  T = size(ŷ, 2)
  U′ = 2*length(y) + 1
  alphas = CUDA.fill(log(zero(ŷ[1])), U′, T)
  nRepeats = count_repeats(y)
  nThreads = min(U′, MAX_THREADS)
  @cuda blocks=1 threads=nThreads compute_alpha_kernel(ŷ, length(y), T, nRepeats, CuArray(y), CuArray(z′), alphas, blank)
  return (loss=-1 * logsumexp(alphas[end-1:end]), alpha=alphas, z′=z′, yhat=ŷ, nRepeats=nRepeats)
end

ctc_loss(ŷ::CuArray, y) = ctc_alpha(ŷ::CuArray, y).loss

function ∇ctc_loss(ŷ::CuArray, y, out)
  loss, alphas, z′, ŷ, nRepeats = out
  U′, T = size(alphas)
  blank = size(ŷ, 1)
  typed_zero = zero(first(ŷ))
  betas = CUDA.fill(log(typed_zero), U′, T)
  output = CUDA.fill(log(typed_zero), U′, T)
  nThreads = min(U′, MAX_THREADS)
  grads = CUDA.fill(log(typed_zero), size(ŷ))
  accum = CUDA.fill(log(typed_zero), size(ŷ))
  @cuda blocks=1 threads=nThreads compute_beta_and_grad_kernel(ŷ, length(y), T, nRepeats, CuArray(z′), alphas, betas, output, accum, grads, blank, loss)
  return grads
end
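
A hedged sketch of calling the GPU path above. It assumes a working CUDA device (`CUDA.functional() == true`) and a Flux checkout containing this commit; the sizes and labels are illustrative.

using CUDA
using Flux
using Flux.Losses: ctc_loss

ŷ = CUDA.rand(Float32, 10, 25)   # raw activations on the GPU, classes × time
y = [3, 7, 2, 2, 5]              # labels stay on the CPU; ctc_alpha copies them with CuArray(y)

loss = ctc_loss(ŷ, y)            # dispatches to the CuArray method defined above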

src/losses/ctc.jl

Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
using Flux
using Zygote: @adjoint
using Statistics
using NNlib

# CPU implementation
"""
    logaddexp(a, b)
Adds log-space `a` and `b` such that the result equals `log(exp(a)+exp(b))`
"""
function logaddexp(a, b)
  isinf(a) && return b
  isinf(b) && return a

  # always want the greater number on the left in the exponentiation;
  # the magnitude difference may end up making the number very positive
  # which will cause exp() to return Inf
  # E.g., a = -900, b = -800, will give exp(-800 - -900), which will be
  # Inf for Float32 values
  if a < b
    a, b = b, a
  end
  return a + log(1+exp(b-a))
end

"""
    add_blanks(z, blank)

Adds blanks to the start and end of `z`, and between items in `z`
"""
function add_blanks(z, blank)
  z′ = fill(blank, 2*length(z) + 1)
  z′[2 .* eachindex(z)] = z
  return z′
end

function ctc_alpha(ŷ::AbstractArray, y)
  typed_zero = zero(ŷ[1])
  ŷ = logsoftmax(ŷ)
  blank = size(ŷ, 1)
  z′ = add_blanks(y, blank)
  T = size(ŷ, 2)
  U′ = length(z′)

  α = fill(log(typed_zero), U′, T)
  α[1,1] = ŷ[blank, 1]
  α[2,1] = ŷ[z′[2], 1]
  for t=2:T
    bound = max(1, U′ - 2(T - t) - 1)
    for u=bound:U′
      if u == 1
        α[u,t] = α[u, t-1]
      else
        α[u,t] = logaddexp(α[u, t-1], α[u-1, t-1])

        # array bounds check and f(u) function from Eq. 7.9
        if u > 2 && !(z′[u] == blank || z′[u-2] == z′[u])
          α[u,t] = logaddexp(α[u,t], α[u-2,t-1])
        end
      end
      α[u,t] += ŷ[z′[u], t]
    end
  end
  return (loss=-1 * logaddexp(α[end,T], α[end-1, T]), alpha=α, zprime=z′, logsoftyhat=ŷ)
end

function ∇ctc_loss(ŷ::AbstractArray, y, out)
  loss, α, z′, ŷ = out
  U′, T = size(α)
  blank = size(ŷ, 1)
  typed_zero = zero(first(α))

  # Calculate beta coefficients, from the bottom-right, to the upper-left
  β = fill(log(typed_zero), U′, T)

  # Fill bottom-right corner so bounding errors can be avoided
  # by starting `u` at `U′-1`
  β[U′, T] = typed_zero
  β[U′-1, T] = typed_zero

  # start at T-1 so that β(T, u) = log(0) for all u < U′ - 1
  for t=(T-1):-1:1
    bound = min(U′, 2t)
    for u=bound:-1:1
      if u == U′
        β[u,t] = ŷ[z′[u], t+1] + β[u, t+1]
      else
        β[u,t] = logaddexp(ŷ[z′[u], t+1] + β[u, t+1], ŷ[z′[u+1], t+1] + β[u+1,t+1])

        # array bounds check and g(u) function from Eq. 7.16
        if u+2 <= U′ && z′[u] != blank && z′[u] != z′[u+2]
          β[u,t] = logaddexp(β[u,t], ŷ[z′[u+2], t+1] + β[u+2, t+1])
        end
      end
    end
  end

  # Accumulate alpha-beta products for each category,
  # then calculate gradients
  accum = fill(log(typed_zero), size(ŷ))
  for t=1:T
    for u=1:U′
      accum[z′[u], t] = logaddexp(accum[z′[u], t], α[u,t] + β[u,t])
    end
  end
  grads = exp.(ŷ) .- exp.(accum .+ loss)
  return grads
end

"""
    ctc_loss(ŷ, y)

Computes the connectionist temporal classification loss between `ŷ`
and `y`.

`ŷ` must be a classes-by-time matrix, i.e., each row
represents a class and each column represents a time step.
Additionally, the `logsoftmax` function will be applied to `ŷ`, so
`ŷ` must be the raw activation values from the neural network and
not, for example, the activations after being passed through a
`softmax` activation function. `y` must be a 1D array of the labels
associated with `ŷ`. The blank label is assumed to be the last label
category in `ŷ`, so it is equivalent to `size(ŷ, 1)`.

Used for sequence-to-sequence classification problems such as
speech recognition and handwriting recognition where the exact
time-alignment of the output (e.g., letters) is not needed to
solve the problem. See [Graves et al. (2006)](https://www.cs.toronto.edu/~graves/icml_2006.pdf)
or [Graves (2012)](https://www.cs.toronto.edu/~graves/preprint.pdf#chapter.7)
for mathematical details.
"""
ctc_loss(ŷ::AbstractArray, y) = ctc_alpha(ŷ, y).loss

@adjoint function ctc_loss(ŷ, y)
  out = ctc_alpha(ŷ, y)
  ctc_loss_pullback(Δ) = (Δ .* ∇ctc_loss(ŷ, y, out), nothing)
  return out.loss, ctc_loss_pullback
end
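
Since the `@adjoint` above wires `∇ctc_loss` into Zygote, gradients with respect to the activations come back with the same classes-by-time shape as `ŷ`. A minimal sketch (sizes and labels are illustrative):

using Flux
using Flux.Losses: ctc_loss
using Zygote: gradient

ŷ = rand(Float32, 10, 25)   # raw activations, classes × time
y = [3, 7, 2, 2, 5]

# Only the activations get a gradient; the label argument maps to `nothing`
# in the pullback defined above.
g, = gradient(ŷ -> ctc_loss(ŷ, y), ŷ)
@assert size(g) == size(ŷ)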
