Skip to content

Commit ef39ba7

Browse files
committed
Corrected tests
1 parent a2ca5a6 commit ef39ba7

File tree

2 files changed

+29
-25
lines changed

2 files changed

+29
-25
lines changed

test/testAD.jl

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,14 @@ testfunction(k,A) = sum(kernelmatrix(k,A))
1717
#For debugging
1818

1919
## Zygote
20-
# Zygote.gradient(x->testfunction(SquaredExponentialKernel(x),A,B),vl)
21-
# Zygote.gradient(x->testfunction(SquaredExponentialKernel(x),A),vl)
22-
# Zygote.gradient(x->testfunction(SquaredExponentialKernel(x),A,B),l)
23-
# Zygote.gradient(x->testfunction(SquaredExponentialKernel(x),A),l)
24-
20+
Zygote.gradient(x->testfunction(SquaredExponentialKernel(x),A),vl)
21+
Zygote.gradient(x->testfunction(SquaredExponentialKernel(x),A,B),vl)
22+
Zygote.gradient(x->testfunction(SquaredExponentialKernel(x),A,B),l)
23+
Zygote.gradient(x->testfunction(SquaredExponentialKernel(x),A),l)
2524
## Tracker
26-
Tracker.gradient(x->testfunction(SquaredExponentialKernel(x),A,B),vl)
25+
Tracker.gradient(x->testfunction(SquaredExponentialKernel(vl),x,B),A)
26+
Tracker.gradient(x->testfunction(SquaredExponentialKernel(l),x[:,:]),A)
27+
# Tracker.gradient(x->testfunction(SquaredExponentialKernel(x),A,B),vl)
2728
Tracker.gradient(x->testfunction(SquaredExponentialKernel(x),A),vl)
2829
Tracker.gradient(x->testfunction(SquaredExponentialKernel(x),A,B),l)
2930
Tracker.gradient(x->testfunction(SquaredExponentialKernel(x),A),l)
@@ -56,42 +57,42 @@ ForwardDiff.gradient(x->testfunction(SquaredExponentialKernel(x[1]),A),[l])
5657
end
5758
end
5859

59-
@testset "Tracker AutomaticDifferentation test" begin
60+
@testset "ForwardDiff AutomaticDifferentation test" begin
6061
@testset "ARD" begin
6162
for k in kernels
62-
@test_nowarn Tracker.gradient(x->testfunction(k(x),A,B),vl)
63-
@test_broken Tracker.gradient(x->testfunction(k(vl),x,B),A)
64-
@test_nowarn Tracker.gradient(x->testfunction(k(x),A),vl)
65-
@test_broken Tracker.gradient(x->testfunction(k(vl),x),A)
63+
@test_nowarn ForwardDiff.gradient(x->testfunction(k(x),A,B),vl)
64+
@test_nowarn ForwardDiff.gradient(x->testfunction(k(vl),x,B),A)
65+
@test_nowarn ForwardDiff.gradient(x->testfunction(k(x),A),vl)
66+
@test_nowarn ForwardDiff.gradient(x->testfunction(k(vl),x),A)
6667
end
6768
end
6869
@testset "ISO" begin
6970
for k in kernels
70-
@test_nowarn Tracker.gradient(x->testfunction(k(x[1]),A,B),[l])
71-
@test_broken Tracker.gradient(x->testfunction(k(l),x,B),A)
72-
@test_nowarn Tracker.gradient(x->testfunction(k(x[1]),A),[l])
73-
@test_broken Tracker.gradient(x->testfunction(k(l),x),A)
74-
71+
@test_nowarn ForwardDiff.gradient(x->testfunction(k(x[1]),A,B),[l])
72+
@test_nowarn ForwardDiff.gradient(x->testfunction(k(l),x,B),A)
73+
@test_nowarn ForwardDiff.gradient(x->testfunction(k(x[1]),A),[l])
74+
@test_nowarn ForwardDiff.gradient(x->testfunction(k(l[1]),x),A)
7575
end
7676
end
7777
end
7878

7979

80-
@testset "ForwardDiff AutomaticDifferentation test" begin
80+
@testset "Tracker AutomaticDifferentation test" begin
8181
@testset "ARD" begin
8282
for k in kernels
83-
@test_nowarn ForwardDiff.gradient(x->testfunction(k(x),A,B),vl)
84-
@test_nowarn ForwardDiff.gradient(x->testfunction(k(vl),x,B),A)
85-
@test_nowarn ForwardDiff.gradient(x->testfunction(k(x),A),vl)
86-
@test_nowarn ForwardDiff.gradient(x->testfunction(k(vl),x),A)
83+
@test all(Tracker.gradient(x->testfunction(k(x),A,B),vl)[1] .≈ ForwardDiff.gradient(x->testfunction(k(x),A,B),vl))
84+
@test_broken all(Tracker.gradient(x->testfunction(k(vl),x,B),A)[1] .≈ ForwardDiff.gradient(x->testfunction(k(vl),x,B),A))
85+
@test all(Tracker.gradient(x->testfunction(k(x),A),vl)[1] .≈ ForwardDiff.gradient(x->testfunction(k(x),A),vl))
86+
@test_broken all.(Tracker.gradient(x->testfunction(k(vl),x),A) .≈ ForwardDiff.gradient(x->testfunction(k(vl),x),A))
8787
end
8888
end
8989
@testset "ISO" begin
9090
for k in kernels
91-
@test_nowarn ForwardDiff.gradient(x->testfunction(k(x[1]),A,B),[l])
92-
@test_nowarn ForwardDiff.gradient(x->testfunction(k(l),x,B),A)
93-
@test_nowarn ForwardDiff.gradient(x->testfunction(k(x[1]),A),[l])
94-
@test_nowarn ForwardDiff.gradient(x->testfunction(k(l[1]),x),A)
91+
@test_nowarn Tracker.gradient(x->testfunction(k(x[1]),A,B),[l])
92+
@test_broken Tracker.gradient(x->testfunction(k(l),x,B),A)
93+
@test_nowarn Tracker.gradient(x->testfunction(k(x[1]),A),[l])
94+
@test_broken Tracker.gradient(x->testfunction(k(l),x),A)
95+
9596
end
9697
end
9798
end

test/types.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
## Testing output types from kernel
2+
3+
#TODO

0 commit comments

Comments
 (0)