Skip to content

Commit bc94a16

Browse files
committed
Fix indentation in ctc.jl
1 parent 6e5fb17 commit bc94a16

File tree

1 file changed

+25
-25
lines changed

1 file changed

+25
-25
lines changed

src/losses/ctc.jl

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -46,19 +46,19 @@ function ctc_alpha(ŷ::AbstractArray, y)
4646
α[1,1] = ŷ[blank, 1]
4747
α[2,1] = ŷ[z′[2], 1]
4848
for t=2:T
49-
bound = max(1, U′ - 2(T - t) - 1)
49+
bound = max(1, U′ - 2(T - t) - 1)
5050
for u=bound:U′
51-
if u == 1
52-
α[u,t] = α[u, t-1]
53-
else
54-
α[u,t] = logaddexp(α[u, t-1], α[u-1, t-1])
55-
56-
# array bounds check and f(u) function from Eq. 7.9
57-
if u > 2 && !(z′[u] == blank || z′[u-2] == z′[u])
58-
α[u,t] = logaddexp(α[u,t], α[u-2,t-1])
59-
end
60-
end
61-
α[u,t] += ŷ[z′[u], t]
51+
if u == 1
52+
α[u,t] = α[u, t-1]
53+
else
54+
α[u,t] = logaddexp(α[u, t-1], α[u-1, t-1])
55+
56+
# array bounds check and f(u) function from Eq. 7.9
57+
if u > 2 && !(z′[u] == blank || z′[u-2] == z′[u])
58+
α[u,t] = logaddexp(α[u,t], α[u-2,t-1])
59+
end
60+
end
61+
α[u,t] += ŷ[z′[u], t]
6262
end
6363
end
6464
return (loss=-1 * logaddexp(α[end,T], α[end-1, T]), alpha=α, zprime=z′, logsoftyhat=ŷ)
@@ -80,18 +80,18 @@ function ∇ctc_loss(ŷ::AbstractArray, y, out)
8080

8181
# start at T-1 so that β(T, u) = log(0) for all u < U′ - 1
8282
for t=(T-1):-1:1
83-
bound = min(U′, 2t)
83+
bound = min(U′, 2t)
8484
for u=bound:-1:1
85-
if u == U′
86-
β[u,t] = ŷ[z′[u], t+1] + β[u, t+1]
87-
else
88-
β[u,t] = logaddexp(ŷ[z′[u], t+1] + β[u, t+1], ŷ[z′[u+1], t+1] + β[u+1,t+1])
85+
if u == U′
86+
β[u,t] = ŷ[z′[u], t+1] + β[u, t+1]
87+
else
88+
β[u,t] = logaddexp(ŷ[z′[u], t+1] + β[u, t+1], ŷ[z′[u+1], t+1] + β[u+1,t+1])
8989

90-
# array bounds check and g(u) function from Eq. 7.16
91-
if u+2 <= U′ && z′[u] != blank && z′[u] != z′[u+2]
92-
β[u,t] = logaddexp(β[u,t], ŷ[z′[u+2], t+1] + β[u+2, t+1])
93-
end
94-
end
90+
# array bounds check and g(u) function from Eq. 7.16
91+
if u+2 <= U′ && z′[u] != blank && z′[u] != z′[u+2]
92+
β[u,t] = logaddexp(β[u,t], ŷ[z′[u+2], t+1] + β[u+2, t+1])
93+
end
94+
end
9595
end
9696
end
9797

@@ -132,7 +132,7 @@ for mathematical details.
132132
ctc_loss(ŷ::AbstractArray, y) = ctc_alpha(ŷ, y).loss
133133

134134
@adjoint function ctc_loss(ŷ, y)
135-
out = ctc_alpha(ŷ, y)
136-
ctc_loss_pullback(Δ) =.* ∇ctc_loss(ŷ, y, out), nothing)
137-
return out.loss, ctc_loss_pullback
135+
out = ctc_alpha(ŷ, y)
136+
ctc_loss_pullback(Δ) =.* ∇ctc_loss(ŷ, y, out), nothing)
137+
return out.loss, ctc_loss_pullback
138138
end

0 commit comments

Comments
 (0)