@@ -46,19 +46,19 @@ function ctc_alpha(ŷ::AbstractArray, y)
46
46
α[1 ,1 ] = ŷ[blank, 1 ]
47
47
α[2 ,1 ] = ŷ[z′[2 ], 1 ]
48
48
for t= 2 : T
49
- bound = max (1 , U′ - 2 (T - t) - 1 )
49
+ bound = max (1 , U′ - 2 (T - t) - 1 )
50
50
for u= bound: U′
51
- if u == 1
52
- α[u,t] = α[u, t- 1 ]
53
- else
54
- α[u,t] = logaddexp (α[u, t- 1 ], α[u- 1 , t- 1 ])
55
-
56
- # array bounds check and f(u) function from Eq. 7.9
57
- if u > 2 && ! (z′[u] == blank || z′[u- 2 ] == z′[u])
58
- α[u,t] = logaddexp (α[u,t], α[u- 2 ,t- 1 ])
59
- end
60
- end
61
- α[u,t] += ŷ[z′[u], t]
51
+ if u == 1
52
+ α[u,t] = α[u, t- 1 ]
53
+ else
54
+ α[u,t] = logaddexp (α[u, t- 1 ], α[u- 1 , t- 1 ])
55
+
56
+ # array bounds check and f(u) function from Eq. 7.9
57
+ if u > 2 && ! (z′[u] == blank || z′[u- 2 ] == z′[u])
58
+ α[u,t] = logaddexp (α[u,t], α[u- 2 ,t- 1 ])
59
+ end
60
+ end
61
+ α[u,t] += ŷ[z′[u], t]
62
62
end
63
63
end
64
64
return (loss= - 1 * logaddexp (α[end ,T], α[end - 1 , T]), alpha= α, zprime= z′, logsoftyhat= ŷ)
@@ -80,18 +80,18 @@ function ∇ctc_loss(ŷ::AbstractArray, y, out)
80
80
81
81
# start at T-1 so that β(T, u) = log(0) for all u < U′ - 1
82
82
for t= (T- 1 ): - 1 : 1
83
- bound = min (U′, 2 t)
83
+ bound = min (U′, 2 t)
84
84
for u= bound: - 1 : 1
85
- if u == U′
86
- β[u,t] = ŷ[z′[u], t+ 1 ] + β[u, t+ 1 ]
87
- else
88
- β[u,t] = logaddexp (ŷ[z′[u], t+ 1 ] + β[u, t+ 1 ], ŷ[z′[u+ 1 ], t+ 1 ] + β[u+ 1 ,t+ 1 ])
85
+ if u == U′
86
+ β[u,t] = ŷ[z′[u], t+ 1 ] + β[u, t+ 1 ]
87
+ else
88
+ β[u,t] = logaddexp (ŷ[z′[u], t+ 1 ] + β[u, t+ 1 ], ŷ[z′[u+ 1 ], t+ 1 ] + β[u+ 1 ,t+ 1 ])
89
89
90
- # array bounds check and g(u) function from Eq. 7.16
91
- if u+ 2 <= U′ && z′[u] != blank && z′[u] != z′[u+ 2 ]
92
- β[u,t] = logaddexp (β[u,t], ŷ[z′[u+ 2 ], t+ 1 ] + β[u+ 2 , t+ 1 ])
93
- end
94
- end
90
+ # array bounds check and g(u) function from Eq. 7.16
91
+ if u+ 2 <= U′ && z′[u] != blank && z′[u] != z′[u+ 2 ]
92
+ β[u,t] = logaddexp (β[u,t], ŷ[z′[u+ 2 ], t+ 1 ] + β[u+ 2 , t+ 1 ])
93
+ end
94
+ end
95
95
end
96
96
end
97
97
@@ -132,7 +132,7 @@ for mathematical details.
132
132
ctc_loss (ŷ:: AbstractArray , y) = ctc_alpha (ŷ, y). loss
133
133
134
134
@adjoint function ctc_loss (ŷ, y)
135
- out = ctc_alpha (ŷ, y)
136
- ctc_loss_pullback (Δ) = (Δ .* ∇ctc_loss (ŷ, y, out), nothing )
137
- return out. loss, ctc_loss_pullback
135
+ out = ctc_alpha (ŷ, y)
136
+ ctc_loss_pullback (Δ) = (Δ .* ∇ctc_loss (ŷ, y, out), nothing )
137
+ return out. loss, ctc_loss_pullback
138
138
end
0 commit comments