
Commit 023cd3d

Move ctc_loss from Flux to NNlib (#426)
* move ctc loss from Flux
* fixup
1 parent c9faa64 commit 023cd3d

File tree

5 files changed: +187 −1 lines changed

Project.toml
src/NNlib.jl
src/ctc.jl
test/ctc.jl
test/runtests.jl

Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 name = "NNlib"
 uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-version = "0.8.8"
+version = "0.8.9"
 
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/NNlib.jl

Lines changed: 3 additions & 0 deletions
@@ -61,6 +61,9 @@ export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter,
 include("conv_bias_act.jl")
 export conv_bias_act, conv_bias_act!
 
+include("ctc.jl")
+export ctc_loss
+
 include("pooling.jl")
 export maxpool, maxpool!, meanpool, meanpool!,
        ∇maxpool, ∇maxpool!, ∇meanpool, ∇meanpool!

src/ctc.jl

Lines changed: 132 additions & 0 deletions
# CTC loss moved from Flux.jl to NNlib + NNlibCUDA

## CPU implementation

"""
    logaddexp(a, b)

Adds log-space `a` and `b` such that the result equals `log(exp(a)+exp(b))`.
"""
function logaddexp(a, b)
    isinf(a) && return b
    isinf(b) && return a

    # always want the greater number on the left in the exponentiation;
    # the magnitude difference may end up making the number very positive
    # which will cause exp() to return Inf
    # E.g., a = -900, b = -800 will give exp(-800 - -900), which will be
    # Inf for Float32 values
    if a < b
        a, b = b, a
    end
    return a + log(1 + exp(b - a))
end

"""
    add_blanks(z, blank)

Adds blanks to the start and end of `z`, and between items in `z`.
"""
function add_blanks(z, blank)
    z′ = fill(blank, 2 * length(z) + 1)
    z′[2 .* eachindex(z)] = z
    return z′
end

function ctc_alpha(ŷ::AbstractArray, y)
    typed_zero = zero(ŷ[1])
    ŷ = logsoftmax(ŷ)
    blank = size(ŷ, 1)
    z′ = add_blanks(y, blank)
    T = size(ŷ, 2)
    U′ = length(z′)

    α = fill(log(typed_zero), U′, T)
    α[1,1] = ŷ[blank, 1]
    α[2,1] = ŷ[z′[2], 1]
    for t=2:T
        bound = max(1, U′ - 2(T - t) - 1)
        for u=bound:U′
            if u == 1
                α[u,t] = α[u, t-1]
            else
                α[u,t] = logaddexp(α[u, t-1], α[u-1, t-1])

                # array bounds check and f(u) function from Eq. 7.9
                if u > 2 && !(z′[u] == blank || z′[u-2] == z′[u])
                    α[u,t] = logaddexp(α[u,t], α[u-2,t-1])
                end
            end
            α[u,t] += ŷ[z′[u], t]
        end
    end
    return (loss=-1 * logaddexp(α[end,T], α[end-1, T]), alpha=α, zprime=z′, logsoftyhat=ŷ)
end

function ∇ctc_loss(ŷ::AbstractArray, y, out)
    loss, α, z′, ŷ = out
    U′, T = size(α)
    blank = size(ŷ, 1)
    typed_zero = zero(first(α))

    # Calculate beta coefficients, from the bottom-right to the upper-left
    β = fill(log(typed_zero), U′, T)

    # Fill bottom-right corner so bounding errors can be avoided
    # by starting `u` at `U′-1`
    β[U′, T] = typed_zero
    β[U′-1, T] = typed_zero

    # start at T-1 so that β(T, u) = log(0) for all u < U′ - 1
    for t=(T-1):-1:1
        bound = min(U′, 2t)
        for u=bound:-1:1
            if u == U′
                β[u,t] = ŷ[z′[u], t+1] + β[u, t+1]
            else
                β[u,t] = logaddexp(ŷ[z′[u], t+1] + β[u, t+1], ŷ[z′[u+1], t+1] + β[u+1,t+1])

                # array bounds check and g(u) function from Eq. 7.16
                if u+2 <= U′ && z′[u] != blank && z′[u] != z′[u+2]
                    β[u,t] = logaddexp(β[u,t], ŷ[z′[u+2], t+1] + β[u+2, t+1])
                end
            end
        end
    end

    # Accumulate alpha-beta products for each category,
    # then calculate gradients
    accum = fill(log(typed_zero), size(ŷ))
    for t=1:T
        for u=1:U′
            accum[z′[u], t] = logaddexp(accum[z′[u], t], α[u,t] + β[u,t])
        end
    end
    grads = exp.(ŷ) .- exp.(accum .+ loss)
    return grads
end

"""
    ctc_loss(ŷ, y)

Computes the connectionist temporal classification loss between `ŷ` and `y`.

`ŷ` must be a classes-by-time matrix, i.e., each row represents a class and
each column represents a time step. Additionally, the `logsoftmax` function
will be applied to `ŷ`, so `ŷ` must be the raw activation values from the
neural network and not, for example, the activations after being passed
through a `softmax` activation function. `y` must be a 1D array of the labels
associated with `ŷ`. The blank label is assumed to be the last label category
in `ŷ`, so it is equivalent to `size(ŷ, 1)`.

Used for sequence-to-sequence classification problems such as speech
recognition and handwriting recognition where the exact time-alignment of the
output (e.g., letters) is not needed to solve the problem. See
[Graves et al. (2006)](https://www.cs.toronto.edu/~graves/icml_2006.pdf) or
[Graves (2012)](https://www.cs.toronto.edu/~graves/preprint.pdf#chapter.7)
for mathematical details.
"""
ctc_loss(ŷ::AbstractArray, y) = ctc_alpha(ŷ, y).loss

function ChainRulesCore.rrule(::typeof(ctc_loss), ŷ, y)
    tmp = ctc_alpha(ŷ, y)
    ctc_loss_pullback(Δ) = (NoTangent(), Δ .* ∇ctc_loss(ŷ, y, tmp), NoTangent())
    return tmp.loss, ctc_loss_pullback
end
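
For orientation, a minimal usage sketch of the newly exported function (the sizes and labels below are illustrative only, not taken from this commit; the gradient call assumes Zygote, mirroring the tests below):

using NNlib: ctc_loss
using Zygote: gradient

x = rand(Float32, 5, 20)   # 5 classes over 20 time steps; raw activations,
                           # since ctc_loss applies logsoftmax internally
y = [1, 3, 2, 4]           # label sequence drawn from classes 1:4;
                           # class 5 (the last row) is the blank

loss = ctc_loss(x, y)                # scalar loss
dx   = gradient(ctc_loss, x, y)[1]   # gradient w.r.t. x, via the rrule above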

test/ctc.jl

Lines changed: 47 additions & 0 deletions
using Test
using NNlib: ctc_loss
using Zygote: gradient
using LinearAlgebra

# Custom function to check numerical gradient of ctc loss,
# based on `ngradient` in `Tracker.jl`
function ctc_ngradient(x, y)
    f = ctc_loss
    grads = zero(x)
    for i in 1:length(x)
        δ = sqrt(eps())
        tmp = x[i]
        x[i] = tmp - δ/2
        y1 = f(x, y)
        x[i] = tmp + δ/2
        y2 = f(x, y)
        x[i] = tmp
        grads[i] = (y2-y1)/δ
    end
    return grads
end

@testset "ctc_loss" begin
    x = rand(10, 50)
    y = rand(1:9, 30)
    g1 = gradient(ctc_loss, x, y)[1]
    g2 = ctc_ngradient(x, y)
    @test g1 ≈ g2 rtol=1e-5 atol=1e-5

    # tests using hand-calculated values
    x = [1. 2. 3.; 2. 1. 1.; 3. 3. 2.]
    y = [1, 2]
    @test ctc_loss(x, y) ≈ 3.6990738275138035

    g = [-0.317671 -0.427729 0.665241; 0.244728 -0.0196172 -0.829811; 0.0729422 0.447346 0.16457]
    ghat = gradient(ctc_loss, x, y)[1]
    @test g ≈ ghat rtol=1e-5 atol=1e-5

    x = [-3. 12. 8. 15.; 4. 20. -2. 20.; 8. -33. 6. 5.]
    y = [1, 2]
    @test ctc_loss(x, y) ≈ 8.02519869363453

    g = [-2.29294774655333e-06 -0.999662657278862 1.75500863563993e-06 0.00669284889063; 0.017985914969696 0.999662657278861 -1.9907078755387e-06 -0.006693150917307; -0.01798362202195 -2.52019580677916e-20 2.35699239251042e-07 3.02026677058789e-07]
    ghat = gradient(ctc_loss, x, y)[1]
    @test g ≈ ghat rtol=1e-5 atol=1e-5
end
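
A worked illustration (not part of the commit) of the label expansion behind the hand-calculated cases above: `x` is 3×3, so the blank is class 3, and the internal, unexported helper expands the target as follows:

z  = [1, 2]
z′ = NNlib.add_blanks(z, 3)   # == [3, 1, 3, 2, 3]: blanks at the start, the end,
                              # and between the two labels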

test/runtests.jl

Lines changed: 4 additions & 0 deletions
@@ -48,6 +48,10 @@ include("test_utils.jl")
     include("conv_bias_act.jl")
 end
 
+@testset "CTC Loss" begin
+    include("ctc.jl")
+end
+
 @testset "Inference" begin
     include("inference.jl")
 end
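
To exercise the new testset locally, one option (a sketch; it assumes the test dependencies such as Zygote are available in the test environment) is simply to run the package tests as usual:

using Pkg
Pkg.test("NNlib")   # runs the full suite, including the new "CTC Loss" testset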
