# GPU implementation

# a port of the GPU kernels from Baidu's C++ warp-ctc package,
# which itself is Copyright 2015-2016 Baidu USA LLC
# and available under the Apache 2.0 license
#
# Apache 2.0 license: https://www.apache.org/licenses/LICENSE-2.0
# GitHub: https://github.com/baidu-research/warp-ctc/
# paper: https://arxiv.org/pdf/1512.02595.pdf

using Flux
using Statistics
using CUDA
using NNlib

const MAX_THREADS = 256

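# Stable addition in log space: returns log(exp(p1) + exp(p2)) without
# overflow by factoring out the larger of the two log-probabilities.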
function log_plus_f(p1, p2)
    isinf(p1) && return p2
    isinf(p2) && return p1
    if p1 < p2
        p1, p2 = p2, p1
    end
    return p1 + CUDA.log(1 + CUDA.exp(p2 - p1))
end

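# Count adjacent repeated labels in the target sequence; each repeat forces an
# extra blank between the two labels, so an utterance is only feasible when
# T >= L + repeats.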
function count_repeats(A)
    repeats = 0
    for (i, elem) in enumerate(A)
        if i > 1 && A[i] == A[i-1]
            repeats += 1
        end
    end
    return repeats
end

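# Forward (alpha) kernel: one block per utterance, with threads striding over
# the positions of the blank-augmented label sequence. alpha[s, t] accumulates
# the log-probability of emitting the first s symbols within the first t frames.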
function compute_alpha_kernel(probs, labelSize, uttLength, repeats, labelsWithoutBlanks, labelsWithBlanks, alpha, blankLabel)

    tid = threadIdx().x
    L = labelSize
    T = uttLength
    S = length(labelsWithBlanks)

    if L + repeats > T
        return nothing
    end
    labels = labelsWithBlanks

    # Corner-case checking
    start = (L + repeats <= T) ? 0 : 1
    last = S > 1 ? 2 : 1

    # Fill in first column (time step)
    i = tid
    while i <= last - start
        alpha[start+i, 1] = probs[labels[start+i], 1]
        i += blockDim().x
    end
    sync_threads()

    # Fill in coefficients for each time step
    for t = 2:T
        # Corner-case checking
        if tid == 1 && !(1 < S - 2*(T-t) - 1)
            if start == 0
                alpha[1, t] = alpha[1, t-1] + probs[blankLabel, t]
            elseif start == 1
                alpha[1, t] = alpha[1, t-1]
            end
        end
        sync_threads()

        # Fill in coefficients for each label class in the target output sequence;
        # each thread will process the calculations for one class
        idx = tid + 1
        while idx <= S
            prevSum = log_plus_f(alpha[idx, t-1], alpha[idx-1, t-1])
            if labels[idx] != blankLabel && idx != 2 && labels[idx] != labels[idx-2]
                prevSum = log_plus_f(prevSum, alpha[idx-2, t-1])
            end
            if idx < S - 2*(T-t) - 1
                alpha[idx, t] = -Inf32
            else
                alpha[idx, t] = prevSum + probs[labels[idx], t]
            end
            idx += blockDim().x
        end
        sync_threads()
    end
    return nothing
end

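# Backward (beta) kernel and gradient computation: fills `beta` with backward
# log-probabilities, combines them with `alphas` into `output`, accumulates
# per-class totals in `accum`, and writes the gradient with respect to the
# unnormalized activations (before logsoftmax) into `grad`.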
function compute_beta_and_grad_kernel(probs, labelSize, uttLength,
                                      repeatsInLabel, labelsWithBlanks,
                                      alphas, beta, output, accum,
                                      grad, blankLabel, loss)

    tid = threadIdx().x
    L = labelSize
    T = uttLength
    S = 2*L + 1
    repeats = repeatsInLabel
    labels = labelsWithBlanks

    if (L + repeats) > T
        return nothing
    end

    # Corner-case checking
    start = S > 1 ? S-2 : 0
    last = L + repeats < T ? S : S-1
    sync_threads()
    i = tid

    # Calculate coefficients for last column (time step)
    # then determine alpha and beta product
    while i <= last - start
        beta[i+start, T] = 0
        output[i+start, T] = beta[i+start, T] + alphas[i+start, T]
        i += blockDim().x
    end
    sync_threads()

    # Fill in `accum` for last column (time step)
    if tid == 1
        for i = 1:S
            labelIdx = labels[i]
            accum[labelIdx, T] = log_plus_f(accum[labelIdx, T], output[i, T])
        end
    end
    sync_threads()

    # Fill in `grad` for last column (time step)
    idx = tid
    while idx <= size(grad, 1)
        s = -Inf32
        for i = 1:S
            s = log_plus_f(s, output[i, T])
        end

        # ∂L/∂a (where a is activation before logsoftmax)
        grad[idx, T] = CUDA.exp(probs[idx, T]) - CUDA.exp(accum[idx, T] - s)
        idx += blockDim().x
    end
    sync_threads()

    # Fill in the rest of the coefficients
    t = T-1
    while t >= 1
        if t < T
            idx = tid
            while idx <= S
                nextSum = probs[labels[idx], t+1] + beta[idx, t+1]
                if idx < S
                    nextSum = log_plus_f(nextSum,
                        probs[labels[idx+1], t+1] + beta[idx+1, t+1])
                end
                if labels[idx] != blankLabel && idx != S-1 && labels[idx] != labels[idx+2]
                    nextSum = log_plus_f(nextSum,
                        probs[labels[idx+2], t+1] + beta[idx+2, t+1])
                end
                if idx > 2*t
                    beta[idx, t] = -Inf32
                else
                    beta[idx, t] = nextSum
                end
                idx += blockDim().x
            end
            sync_threads()
            idx = tid
            while idx <= S
                output[idx, t] = alphas[idx, t] + beta[idx, t]
                idx += blockDim().x
            end
            sync_threads()
        end
        sync_threads()

        # Calculate accumulated alpha-beta products for each label class for
        # each time step; used in calculating gradients
        if tid == 1
            for i = 1:S
                labelIdx = labels[i]
                accum[labelIdx, t] = log_plus_f(accum[labelIdx, t], output[i, t])
            end
        end
        sync_threads()
        idx = tid

        # Calculate gradients
        while idx <= size(grad, 1)

            # ∂L/∂a (where a is activation before logsoftmax)
            grad[idx, t] = CUDA.exp(probs[idx, t]) - CUDA.exp(accum[idx, t] + loss)
            idx += blockDim().x
        end
        sync_threads()
        t -= 1
        sync_threads()
    end
    return nothing
end

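# Forward pass on the GPU: applies logsoftmax, interleaves blanks into the
# label sequence to build z′ (the blank is the last class), launches the alpha
# kernel, and returns the loss along with the intermediates the gradient needs.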
function ctc_alpha(ŷ::CuArray, y)
    ŷ = logsoftmax(ŷ)
    blank = size(ŷ, 1)
    z′ = fill(blank, 2 * length(y) + 1)
    z′[eachindex(y) .* 2] = y
    T = size(ŷ, 2)
    U′ = 2 * length(y) + 1
    alphas = CUDA.fill(log(zero(ŷ[1])), U′, T)
    nRepeats = count_repeats(y)
    nThreads = min(U′, MAX_THREADS)
    @cuda blocks=1 threads=nThreads compute_alpha_kernel(ŷ, length(y), T, nRepeats, CuArray(y), CuArray(z′), alphas, blank)
    return (loss=-1 * logsumexp(alphas[end-1:end]), alpha=alphas, z′=z′, yhat=ŷ, nRepeats=nRepeats)
end

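# CTC loss for the activations ŷ (classes × time, last row treated as the
# blank) against the label sequence y; logsoftmax is applied internally.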
ctc_loss(ŷ::CuArray, y) = ctc_alpha(ŷ, y).loss

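# Gradient of the CTC loss with respect to the activations ŷ (before the
# logsoftmax), reusing the intermediates returned by `ctc_alpha` as `out`.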
function ∇ctc_loss(ŷ::CuArray, y, out)
    loss, alphas, z′, ŷ, nRepeats = out
    U′, T = size(alphas)
    blank = size(ŷ, 1)
    typed_zero = zero(first(ŷ))
    betas = CUDA.fill(log(typed_zero), U′, T)
    output = CUDA.fill(log(typed_zero), U′, T)
    nThreads = min(U′, MAX_THREADS)
    grads = CUDA.fill(log(typed_zero), size(ŷ))
    accum = CUDA.fill(log(typed_zero), size(ŷ))
    @cuda blocks=1 threads=nThreads compute_beta_and_grad_kernel(ŷ, length(y), T, nRepeats, CuArray(z′), alphas, betas, output, accum, grads, blank, loss)
    return grads
end
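
# A minimal usage sketch (not part of the Baidu port): it assumes a functional
# CUDA device and that the labels in `y` are drawn from 1:(size(ŷ, 1) - 1),
# since the last row of ŷ is reserved for the blank. The sizes and label
# values below are illustrative only. The guard keeps the demo from running
# when this file is merely included.
if abspath(PROGRAM_FILE) == @__FILE__ && CUDA.functional()
    ŷ = CUDA.rand(Float32, 5, 30)    # 5 classes (row 5 = blank), 30 time steps
    y = [1, 2, 3, 2]                 # target label sequence, without blanks
    @show ctc_loss(ŷ, y)             # scalar negative log-likelihood
    # gradient with respect to the activations ŷ, same shape as ŷ
    @show size(∇ctc_loss(ŷ, y, ctc_alpha(ŷ, y)))
end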