1
+ using KernelFunctions
1
2
using Zygote, ForwardDiff, Tracker
3
+ using Test
2
4
3
5
dims = [10 ,5 ]
4
6
@@ -15,10 +17,10 @@ testfunction(k,A) = sum(kernelmatrix(k,A))
15
17
# For debugging
16
18
17
19
# # Zygote
18
- Zygote. gradient (x-> testfunction (SquaredExponentialKernel (x),A,B),vl)
19
- Zygote. gradient (x-> testfunction (SquaredExponentialKernel (x),A),vl)
20
- Zygote. gradient (x-> testfunction (SquaredExponentialKernel (x),A,B),l)
21
- Zygote. gradient (x-> testfunction (SquaredExponentialKernel (x),A),l)
20
+ # Zygote.gradient(x->testfunction(SquaredExponentialKernel(x),A,B),vl)
21
+ # Zygote.gradient(x->testfunction(SquaredExponentialKernel(x),A),vl)
22
+ # Zygote.gradient(x->testfunction(SquaredExponentialKernel(x),A,B),l)
23
+ # Zygote.gradient(x->testfunction(SquaredExponentialKernel(x),A),l)
22
24
23
25
# # Tracker
24
26
Tracker. gradient (x-> testfunction (SquaredExponentialKernel (x),A,B),vl)
@@ -37,18 +39,18 @@ ForwardDiff.gradient(x->testfunction(SquaredExponentialKernel(x[1]),A),[l])
37
39
@testset " Zygote Automatic Differentiation test" begin
38
40
@testset " ARD" begin
39
41
for k in kernels
40
- @test Zygote. gradient (x-> testfunction (k (x),A,B),vl)
41
- @test Zygote. gradient (x-> testfunction (k (vl),x,B),A)
42
- @test Zygote. gradient (x-> testfunction (k (x),A),vl)
43
- @test Zygote. gradient (x-> testfunction (k (vl),x),A)
42
+ @test_broken Zygote. gradient (x-> testfunction (k (x),A,B),vl)
43
+ @test_broken Zygote. gradient (x-> testfunction (k (vl),x,B),A)
44
+ @test_broken Zygote. gradient (x-> testfunction (k (x),A),vl)
45
+ @test_broken Zygote. gradient (x-> testfunction (k (vl),x),A)
44
46
end
45
47
end
46
48
@testset " ISO" begin
47
49
for k in kernels
48
- @test Zygote. gradient (x-> testfunction (k (x),A,B),l)
49
- @test Zygote. gradient (x-> testfunction (k (l),x,B),A)
50
- @test Zygote. gradient (x-> testfunction (k (x),A),l)
51
- @test Zygote. gradient (x-> testfunction (k (l),x),A)
50
+ @test_broken Zygote. gradient (x-> testfunction (k (x),A,B),l)
51
+ @test_broken Zygote. gradient (x-> testfunction (k (l),x,B),A)
52
+ @test_broken Zygote. gradient (x-> testfunction (k (x),A),l)
53
+ @test_broken Zygote. gradient (x-> testfunction (k (l),x),A)
52
54
53
55
end
54
56
end
57
59
@testset " Tracker AutomaticDifferentation test" begin
58
60
@testset " ARD" begin
59
61
for k in kernels
60
- @test Tracker. gradient (x-> testfunction (k (x),A,B),vl)
61
- @test Tracker. gradient (x-> testfunction (k (vl),x,B),A)
62
- @test Tracker. gradient (x-> testfunction (k (x),A),vl)
63
- @test Tracker. gradient (x-> testfunction (k (vl),x),A)
62
+ @test_nowarn Tracker. gradient (x-> testfunction (k (x),A,B),vl)
63
+ @test_broken Tracker. gradient (x-> testfunction (k (vl),x,B),A)
64
+ @test_nowarn Tracker. gradient (x-> testfunction (k (x),A),vl)
65
+ @test_broken Tracker. gradient (x-> testfunction (k (vl),x),A)
64
66
end
65
67
end
66
68
@testset " ISO" begin
67
69
for k in kernels
68
- @test Tracker. gradient (x-> testfunction (k (x[1 ]),A,B),[l])
69
- @test Tracker. gradient (x-> testfunction (k (l),x,B),A)
70
- @test Tracker. gradient (x-> testfunction (k (x),A),[l])
71
- @test Tracker. gradient (x-> testfunction (k (l[ 1 ] ),x),A)
70
+ @test_nowarn Tracker. gradient (x-> testfunction (k (x[1 ]),A,B),[l])
71
+ @test_broken Tracker. gradient (x-> testfunction (k (l),x,B),A)
72
+ @test_nowarn Tracker. gradient (x-> testfunction (k (x[ 1 ] ),A),[l])
73
+ @test_broken Tracker. gradient (x-> testfunction (k (l),x),A)
72
74
73
75
end
74
76
end
78
80
@testset " ForwardDiff AutomaticDifferentation test" begin
79
81
@testset " ARD" begin
80
82
for k in kernels
81
- @test ForwardDiff. gradient (x-> testfunction (k (x),A,B),vl)
82
- @test ForwardDiff. gradient (x-> testfunction (k (vl),x,B),A)
83
- @test ForwardDiff. gradient (x-> testfunction (k (x),A),vl)
84
- @test ForwardDiff. gradient (x-> testfunction (k (vl),x),A)
83
+ @test_nowarn ForwardDiff. gradient (x-> testfunction (k (x),A,B),vl)
84
+ @test_nowarn ForwardDiff. gradient (x-> testfunction (k (vl),x,B),A)
85
+ @test_nowarn ForwardDiff. gradient (x-> testfunction (k (x),A),vl)
86
+ @test_nowarn ForwardDiff. gradient (x-> testfunction (k (vl),x),A)
85
87
end
86
88
end
87
89
@testset " ISO" begin
88
90
for k in kernels
89
- @test ForwardDiff. gradient (x-> testfunction (k (x[1 ]),A,B),[l])
90
- @test ForwardDiff. gradient (x-> testfunction (k (l),x,B),A)
91
- @test ForwardDiff. gradient (x-> testfunction (k (x),A),[l])
92
- @test ForwardDiff. gradient (x-> testfunction (k (l[1 ]),x),A)
93
-
91
+ @test_nowarn ForwardDiff. gradient (x-> testfunction (k (x[1 ]),A,B),[l])
92
+ @test_nowarn ForwardDiff. gradient (x-> testfunction (k (l),x,B),A)
93
+ @test_nowarn ForwardDiff. gradient (x-> testfunction (k (x[1 ]),A),[l])
94
+ @test_nowarn ForwardDiff. gradient (x-> testfunction (k (l[1 ]),x),A)
94
95
end
95
96
end
96
97
end
0 commit comments