Skip to content

Commit ee6172c

Browse files
committed
Added Tracker
1 parent 65505a1 commit ee6172c

File tree

1 file changed

+38
-3
lines changed

1 file changed

+38
-3
lines changed

test/testAD.jl

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using Zygote, ForwardDiff
1+
using Zygote, ForwardDiff, Tracker
22

33
dims = [10,5]
44

@@ -13,11 +13,25 @@ testfunction(k,A) = sum(kernelmatrix(k,A))
1313

1414

1515
#For debugging
16+
17+
## Zygote
1618
Zygote.gradient(x->testfunction(SquaredExponentialKernel(x),A,B),vl)
17-
Zygote.gradient(x->testfunction(SquaredExponentialKernel(x),A,B),vl)
19+
Zygote.gradient(x->testfunction(SquaredExponentialKernel(x),A),vl)
1820
Zygote.gradient(x->testfunction(SquaredExponentialKernel(x),A,B),l)
19-
ForwardDiff.gradient(x->testfunction(SquaredExponentialKernel(x[1]),A),[l])
21+
Zygote.gradient(x->testfunction(SquaredExponentialKernel(x),A),l)
22+
23+
## Tracker
24+
Tracker.gradient(x->testfunction(SquaredExponentialKernel(x),A,B),vl)
25+
Tracker.gradient(x->testfunction(SquaredExponentialKernel(x),A),vl)
26+
Tracker.gradient(x->testfunction(SquaredExponentialKernel(x),A,B),l)
27+
Tracker.gradient(x->testfunction(SquaredExponentialKernel(x),A),l)
28+
2029

30+
## ForwardDiff
31+
ForwardDiff.gradient(x->testfunction(SquaredExponentialKernel(x),A,B),vl) #
32+
ForwardDiff.gradient(x->testfunction(SquaredExponentialKernel(x),A),vl) #
33+
ForwardDiff.gradient(x->testfunction(SquaredExponentialKernel(x[1]),A,B),[l])
34+
ForwardDiff.gradient(x->testfunction(SquaredExponentialKernel(x[1]),A),[l])
2135
##Eventually store real results in file
2236

2337
@testset "Zygote Automatic Differentiation test" begin
@@ -40,6 +54,27 @@ ForwardDiff.gradient(x->testfunction(SquaredExponentialKernel(x[1]),A),[l])
4054
end
4155
end
4256

57+
@testset "Tracker AutomaticDifferentation test" begin
58+
@testset "ARD" begin
59+
for k in kernels
60+
@test Tracker.gradient(x->testfunction(k(x),A,B),vl)
61+
@test Tracker.gradient(x->testfunction(k(vl),x,B),A)
62+
@test Tracker.gradient(x->testfunction(k(x),A),vl)
63+
@test Tracker.gradient(x->testfunction(k(vl),x),A)
64+
end
65+
end
66+
@testset "ISO" begin
67+
for k in kernels
68+
@test Tracker.gradient(x->testfunction(k(x[1]),A,B),[l])
69+
@test Tracker.gradient(x->testfunction(k(l),x,B),A)
70+
@test Tracker.gradient(x->testfunction(k(x),A),[l])
71+
@test Tracker.gradient(x->testfunction(k(l[1]),x),A)
72+
73+
end
74+
end
75+
end
76+
77+
4378
@testset "ForwardDiff AutomaticDifferentation test" begin
4479
@testset "ARD" begin
4580
for k in kernels

0 commit comments

Comments
 (0)