# CTC loss moved from Flux.jl to NNlib + NNlibCUDA

import NNlib: ctc_loss, ctc_alpha, ∇ctc_loss

# logsoftmax/logsumexp come from NNlib and the CUDA symbols (CuArray, cu, @cuda,
# threadIdx, blockDim, sync_threads, CUDA.fill) from CUDA.jl; the parent
# NNlibCUDA module is expected to load both.
using CUDA
using NNlib: logsoftmax, logsumexp

## GPU implementation

# a port of the GPU kernels from Baidu's C++ warp-ctc package,
# which itself is Copyright 2015-2016 Baidu USA LLC
# and available under the Apache 2.0 license
#
# Apache 2.0 license: https://www.apache.org/licenses/LICENSE-2.0
# GitHub: https://github.com/baidu-research/warp-ctc/
# paper: https://arxiv.org/pdf/1512.02595.pdf

# Maximum number of CUDA threads per block used when launching the kernels below
const MAX_THREADS = 256

function log_plus_f(p1, p2)
    isinf(p1) && return p2
    isinf(p2) && return p1
    if p1 < p2
        p1, p2 = p2, p1
    end
    return p1 + log(1 + exp(p2 - p1))
end
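
# For reference: log_plus_f(a, b) == log(exp(a) + exp(b)), evaluated stably in
# log space. Illustrative checks (the concrete values are only examples):
#   log_plus_f(log(0.2f0), log(0.3f0)) ≈ log(0.5f0)
#   log_plus_f(-Inf32, x) == x          # log(0) absorbs into the other term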

function count_repeats(A)
    repeats = 0
    for (i, elem) in enumerate(A)
        if i > 1 && A[i] == A[i-1]
            repeats += 1
        end
    end
    return repeats
end
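
# Counts how many label positions repeat their immediate predecessor, e.g.
# count_repeats([1, 1, 2, 2, 2]) == 3. CTC needs this because repeated labels
# must be separated by blanks, so a sequence with repeats requires extra time
# steps to be representable (see the `L + repeats > T` checks in the kernels).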

function compute_alpha_kernel(probs, labelSize, uttLength, repeats, labelsWithoutBlanks, labelsWithBlanks, alpha, blankLabel)

    tid = threadIdx().x
    L = labelSize
    T = uttLength
    S = length(labelsWithBlanks)

    if L + repeats > T
        return nothing
    end
    labels = labelsWithBlanks

    # Corner-case checking
    start = (L + repeats <= T) ? 0 : 1
    last = S > 1 ? 2 : 1

    # Fill in first column (time step)
    i = tid
    while i <= last - start
        alpha[start+i, 1] = probs[labels[start+i], 1]
        i += blockDim().x
    end
    sync_threads()

    # Fill in coefficients for each time step
    for t = 2:T
        # Corner-case checking
        if tid == 1 && !(1 < S - 2*(T-t) - 1)
            if start == 0
                alpha[1, t] = alpha[1, t-1] + probs[blankLabel, t]
            elseif start == 1
                alpha[1, t] = alpha[1, t-1]
            end
        end
        sync_threads()

        # Fill in coefficients for each label class in the target output sequence;
        # each thread will process the calculations for one class
        idx = tid + 1
        while idx <= S
            prevSum = log_plus_f(alpha[idx, t-1], alpha[idx-1, t-1])
            if labels[idx] != blankLabel && idx != 2 && labels[idx] != labels[idx-2]
                prevSum = log_plus_f(prevSum, alpha[idx-2, t-1])
            end
            if idx < S - 2*(T-t) - 1
                alpha[idx, t] = -Inf32
            else
                alpha[idx, t] = prevSum + probs[labels[idx], t]
            end
            idx += blockDim().x
        end
        sync_threads()
    end
    return nothing
end
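
# In standard CTC notation the kernel above computes, in log space,
#   α(s, t) = logaddexp(α(s, t-1), α(s-1, t-1) [, α(s-2, t-1)]) + log p(labels[s], t)
# where the α(s-2, t-1) term is included only when labels[s] is a non-blank that
# differs from labels[s-2], and positions that can no longer reach the end of the
# label sequence in the remaining time steps are pinned to -Inf.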

function compute_beta_and_grad_kernel(probs, labelSize, uttLength,
                                      repeatsInLabel, labelsWithBlanks,
                                      alphas, beta, output, accum,
                                      grad, blankLabel, loss)

    tid = threadIdx().x
    L = labelSize
    T = uttLength
    S = 2*L + 1
    repeats = repeatsInLabel
    labels = labelsWithBlanks

    if (L + repeats) > T
        return nothing
    end

    # Corner-case checking
    start = S > 1 ? S-2 : 0
    last = L + repeats < T ? S : S-1
    sync_threads()
    i = tid

    # Calculate coefficients for last column (time step)
    # then determine alpha and beta product
    while i <= last - start
        beta[i+start, T] = 0
        output[i+start, T] = beta[i+start, T] + alphas[i+start, T]
        i += blockDim().x
    end
    sync_threads()

    # Fill in `accum` for last column (time step)
    if tid == 1
        for i = 1:S
            labelIdx = labels[i]
            accum[labelIdx, T] = log_plus_f(accum[labelIdx, T], output[i, T])
        end
    end
    sync_threads()

    # Fill in `grad` for last column (time step)
    idx = tid
    while idx <= size(grad, 1)
        s = -Inf32
        for i = 1:S
            s = log_plus_f(s, output[i, T])
        end

        # ∂L/∂a (where a is activation before logsoftmax)
        grad[idx, T] = exp(probs[idx, T]) - exp(accum[idx, T] - s)
        idx += blockDim().x
    end
    sync_threads()

    # Fill in the rest of the coefficients
    t = T-1
    while t >= 1
        if t < T
            idx = tid
            while idx <= S
                nextSum = probs[labels[idx], t+1] + beta[idx, t+1]
                if idx < S
                    nextSum = log_plus_f(nextSum,
                        probs[labels[idx+1], t+1] + beta[idx+1, t+1])
                end
                if labels[idx] != blankLabel && idx != S-1 && labels[idx] != labels[idx+2]
                    nextSum = log_plus_f(nextSum,
                        probs[labels[idx+2], t+1] + beta[idx+2, t+1])
                end
                if idx > 2*t
                    beta[idx, t] = -Inf32
                else
                    beta[idx, t] = nextSum
                end
                idx += blockDim().x
            end
            sync_threads()
            idx = tid
            while idx <= S
                output[idx, t] = alphas[idx, t] + beta[idx, t]
                idx += blockDim().x
            end
            sync_threads()
        end
        sync_threads()

        # Calculate accumulated alpha-beta products for each label class for
        # each time step; used in calculating gradients
        if tid == 1
            for i = 1:S
                labelIdx = labels[i]
                accum[labelIdx, t] = log_plus_f(accum[labelIdx, t], output[i, t])
            end
        end
        sync_threads()
        idx = tid

        # Calculate gradients
        while idx <= size(grad, 1)

            # ∂L/∂a (where a is activation before logsoftmax)
            grad[idx, t] = exp(probs[idx, t]) - exp(accum[idx, t] + loss)
            idx += blockDim().x
        end
        sync_threads()
        t -= 1
        sync_threads()
    end
    return nothing
end
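
# The gradient written by this kernel is the standard CTC result for the inputs
# to logsoftmax: for class k at time t,
#   ∂loss/∂a[k, t] = softmax(a)[k, t] - exp(accum[k, t] + loss)
# i.e. the model's predicted probability minus the posterior probability
# (normalized alpha-beta mass) that class k is emitted at time t.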

function ctc_alpha(ŷ::CuArray, y)
    ŷ = logsoftmax(ŷ)
    blank = size(ŷ, 1)
    ycu = cu(y)
    z′ = CUDA.fill(blank, 2 * length(y) + 1)
    z′[eachindex(y) .* 2] .= ycu
    T = size(ŷ, 2)
    U′ = 2 * length(y) + 1
    alphas = CUDA.fill(log(zero(eltype(ŷ))), U′, T)
    nRepeats = count_repeats(CUDA.adapt(Array, y))
    nThreads = min(U′, MAX_THREADS)
    @cuda blocks=1 threads=nThreads compute_alpha_kernel(ŷ, length(y), T, nRepeats, ycu, z′, alphas, blank)
    return (loss=-1 * logsumexp(alphas[end-1:end]), alpha=alphas, z′=z′, yhat=ŷ, nRepeats=nRepeats)
end
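
# `z′` is the label sequence with blanks interleaved, [blank, y[1], blank, y[2], ..., blank],
# and the returned loss is -logaddexp of the two valid terminal alphas (paths
# ending on the final label or on the trailing blank).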

ctc_loss(ŷ::CuArray, y) = ctc_alpha(ŷ, y).loss
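
# Example usage (illustrative; array sizes and values are assumed):
#   ŷ = CUDA.rand(Float32, 5, 30)   # 5 classes × 30 time steps; the last class (5) is the blank
#   y = [1, 2, 3]                   # target label sequence, drawn from 1:4
#   ctc_loss(ŷ, y)                  # scalar negative log-likelihood of y given ŷ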

function ∇ctc_loss(ŷ::CuArray, y, out)
    loss, alphas, z′, ŷ, nRepeats = out
    U′, T = size(alphas)
    blank = size(ŷ, 1)
    typed_zero = zero(eltype(ŷ))
    betas = CUDA.fill(log(typed_zero), U′, T)
    output = CUDA.fill(log(typed_zero), U′, T)
    nThreads = min(U′, MAX_THREADS)
    grads = CUDA.fill(log(typed_zero), size(ŷ))
    accum = CUDA.fill(log(typed_zero), size(ŷ))
    @cuda blocks=1 threads=nThreads compute_beta_and_grad_kernel(ŷ, length(y), T, nRepeats, CuArray(z′), alphas, betas, output, accum, grads, blank, loss)
    return grads
end
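
# Sketch of how these pieces can be tied into an AD rule; this is an assumed,
# illustrative form (the rule itself is defined on the NNlib side), shown only
# to make the intended call pattern of ctc_alpha / ∇ctc_loss clear:
#
#   function ChainRulesCore.rrule(::typeof(ctc_loss), ŷ::CuArray, y)
#       out = ctc_alpha(ŷ, y)
#       ctc_loss_pullback(Δ) = (NoTangent(), Δ .* ∇ctc_loss(ŷ, y, out), NoTangent())
#       return out.loss, ctc_loss_pullback
#   end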