You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
g = [-2.29294774655333e-06-0.9996626572788621.75500863563993e-060.00669284889063; 0.0179859149696960.999662657278861-1.9907078755387e-06-0.006693150917307; -0.01798362202195-2.52019580677916e-202.35699239251042e-073.02026677058789e-07]
Copy file name to clipboardExpand all lines: test/ctc.jl
+4-17Lines changed: 4 additions & 17 deletions
Original file line number
Diff line number
Diff line change
@@ -22,25 +22,12 @@ function ctc_ngradient(x, y)
22
22
return grads
23
23
end
24
24
25
-
functionF(A, blank)
26
-
prev = A[1]
27
-
z = [prev]
28
-
for curr in A[2:end]
29
-
if curr != prev
30
-
push!(z, curr)
31
-
end
32
-
prev = curr
33
-
end
34
-
filter!(x -> x != blank, z)
35
-
return z
36
-
end
37
-
38
25
@testset"ctc_loss"begin
39
26
x =rand(10, 50)
40
-
y =F(rand(1:9, 30), 10)
27
+
y =rand(1:9, 30)
41
28
g1 =gradient(ctc_loss, x, y)[1]
42
29
g2 =ctc_ngradient(x, y)
43
-
@test g1 ≈ g2 rtol=1e-5 atol=1e-5
30
+
@test g1 ≈ g2 rtol=1e-5 atol=1e-5
44
31
45
32
# tests using hand-calculated values
46
33
x = [1.2.3.; 2.1.1.; 3.3.2.]
@@ -49,13 +36,13 @@ end
49
36
50
37
g = [-0.317671-0.4277290.665241; 0.244728-0.0196172-0.829811; 0.07294220.4473460.16457]
51
38
ghat =gradient(ctc_loss, x, y)[1]
52
-
@test g ≈ ghat rtol=1e-5 atol=1e-5
39
+
@test g ≈ ghat rtol=1e-5 atol=1e-5
53
40
54
41
x = [-3.12.8.15.; 4.20.-2.20.; 8.-33.6.5.]
55
42
y = [1, 2]
56
43
@testctc_loss(x, y) ≈8.02519869363453
57
44
58
45
g = [-2.29294774655333e-06-0.9996626572788621.75500863563993e-060.00669284889063; 0.0179859149696960.999662657278861-1.9907078755387e-06-0.006693150917307; -0.01798362202195-2.52019580677916e-202.35699239251042e-073.02026677058789e-07]
0 commit comments