Skip to content

Commit 6e5fb17

Browse files
committed
Remove F in ctc tests; update ctc-gpu test syntax
1 parent e1e8cc8 commit 6e5fb17

File tree

2 files changed

+9
-35
lines changed

2 files changed

+9
-35
lines changed

test/ctc-gpu.jl

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,32 +23,19 @@ function ctc_ngradient(x, y)
2323
return grads
2424
end
2525

26-
function F(A, blank)
27-
prev = A[1]
28-
z = [prev]
29-
for curr in A[2:end]
30-
if curr != prev
31-
push!(z, curr)
32-
end
33-
prev = curr
34-
end
35-
filter!(x -> x != blank, z)
36-
return z
37-
end
38-
3926
@testset "ctc-gpu" begin
4027
x = rand(10, 50)
41-
y = F(rand(1:9, 30), 10)
28+
y = rand(1:9, 30)
4229
x_cu = CuArray(x)
4330
g1 = gradient(ctc_loss, x_cu, y)[1]
4431
g1 = g1 |> collect
4532
g2 = ctc_ngradient(x, y)
46-
@test all(isapprox.(g1, g2, rtol=1e-5, atol=1e-5))
33+
@test g1 g2 rtol=1e-5 atol=1e-5
4734

4835
# test that GPU loss matches CPU implementation
4936
l1 = ctc_loss(x_cu, y)
5037
l2 = ctc_loss(x, y)
51-
@test all(isapprox.(l1, l2, rtol=1e-5, atol=1e-5))
38+
@test l1 l2
5239

5340
# tests using hand-calculated values
5441
x_cu = [1. 2. 3.; 2. 1. 1.; 3. 3. 2.] |> CuArray
@@ -57,13 +44,13 @@ end
5744

5845
g = [-0.317671 -0.427729 0.665241; 0.244728 -0.0196172 -0.829811; 0.0729422 0.447346 0.16457]
5946
ghat = gradient(ctc_loss, x_cu, y)[1] |> collect
60-
@test all(isapprox.(g, ghat, rtol=1e-5, atol=1e-5))
47+
@test g ghat rtol=1e-5 atol=1e-5
6148

6249
x_cu = [-3. 12. 8. 15.; 4. 20. -2. 20.; 8. -33. 6. 5.] |> CuArray
6350
y = [1, 2] |> CuArray
6451
@test ctc_loss(x_cu, y) 8.02519869363453
6552

6653
g = [-2.29294774655333e-06 -0.999662657278862 1.75500863563993e-06 0.00669284889063; 0.017985914969696 0.999662657278861 -1.9907078755387e-06 -0.006693150917307; -0.01798362202195 -2.52019580677916e-20 2.35699239251042e-07 3.02026677058789e-07]
6754
ghat = gradient(ctc_loss, x_cu, y)[1] |> collect
68-
@test all(isapprox.(g, ghat, rtol=1e-5, atol=1e-5))
55+
@test g ghat rtol=1e-5 atol=1e-5
6956
end

test/ctc.jl

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,25 +22,12 @@ function ctc_ngradient(x, y)
2222
return grads
2323
end
2424

25-
function F(A, blank)
26-
prev = A[1]
27-
z = [prev]
28-
for curr in A[2:end]
29-
if curr != prev
30-
push!(z, curr)
31-
end
32-
prev = curr
33-
end
34-
filter!(x -> x != blank, z)
35-
return z
36-
end
37-
3825
@testset "ctc_loss" begin
3926
x = rand(10, 50)
40-
y = F(rand(1:9, 30), 10)
27+
y = rand(1:9, 30)
4128
g1 = gradient(ctc_loss, x, y)[1]
4229
g2 = ctc_ngradient(x, y)
43-
@test g1 g2 rtol=1e-5 atol=1e-5
30+
@test g1 g2 rtol=1e-5 atol=1e-5
4431

4532
# tests using hand-calculated values
4633
x = [1. 2. 3.; 2. 1. 1.; 3. 3. 2.]
@@ -49,13 +36,13 @@ end
4936

5037
g = [-0.317671 -0.427729 0.665241; 0.244728 -0.0196172 -0.829811; 0.0729422 0.447346 0.16457]
5138
ghat = gradient(ctc_loss, x, y)[1]
52-
@test g ghat rtol=1e-5 atol=1e-5
39+
@test g ghat rtol=1e-5 atol=1e-5
5340

5441
x = [-3. 12. 8. 15.; 4. 20. -2. 20.; 8. -33. 6. 5.]
5542
y = [1, 2]
5643
@test ctc_loss(x, y) 8.02519869363453
5744

5845
g = [-2.29294774655333e-06 -0.999662657278862 1.75500863563993e-06 0.00669284889063; 0.017985914969696 0.999662657278861 -1.9907078755387e-06 -0.006693150917307; -0.01798362202195 -2.52019580677916e-20 2.35699239251042e-07 3.02026677058789e-07]
5946
ghat = gradient(ctc_loss, x, y)[1]
60-
@test g ghat rtol=1e-5 atol=1e-5
47+
@test g ghat rtol=1e-5 atol=1e-5
6148
end

0 commit comments

Comments
 (0)