Skip to content

Commit 6b5ba4d

Browse files
committed
Corrected Tests Zygote Adjoints
1 parent ffefd1f commit 6b5ba4d

File tree

1 file changed

+30
-10
lines changed

1 file changed

+30
-10
lines changed

test/zygote_adjoints.jl

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,37 @@
55
y = rand(rng, 5)
66
r = rand(rng, 5)
77

8-
gzeucl = gradient(Val(:Zygote), xy -> evaluate(Euclidean(), xy[1], xy[2]), [x,y])
9-
gzsqeucl = gradient(Val(:Zygote), xy -> evaluate(SqEuclidean(), xy[1], xy[2]), [x,y])
10-
gzdotprod = gradient(Val(:Zygote), xy -> evaluate(KernelFunctions.DotProduct(), xy[1], xy[2]), [x,y])
11-
gzdelta = gradient(Val(:Zygote), xy -> evaluate(KernelFunctions.Delta(), xy[1], xy[2]), [x,y])
12-
gzsinus = gradient(Val(:Zygote), xy -> evaluate(KernelFunctions.Sinus(r), xy[1], xy[2]), [x,y])
8+
gzeucl = gradient(:Zygote, [x,y]) do xy
9+
evaluate(Euclidean(), xy[1], xy[2])
10+
end
11+
gzsqeucl = gradient(:Zygote, [x,y]) do xy
12+
evaluate(SqEuclidean(), xy[1], xy[2])
13+
end
14+
gzdotprod = gradient(:Zygote, [x,y]) do xy
15+
evaluate(KernelFunctions.DotProduct(), xy[1], xy[2])
16+
end
17+
gzdelta = gradient(:Zygote, [x,y]) do xy
18+
evaluate(KernelFunctions.Delta(), xy[1], xy[2])
19+
end
20+
gzsinus = gradient(:Zygote, [x,y]) do xy
21+
evaluate(KernelFunctions.Sinus(r), xy[1], xy[2])
22+
end
1323

14-
gfeucl = gradient(Val(:FiniteDiff), xy -> evaluate(Euclidean(), xy[1], xy[2]), [x,y])
15-
gfsqeucl = gradient(Val(:FiniteDiff), xy -> evaluate(SqEuclidean(), xy[1], xy[2]), [x,y])
16-
gfdotprod = gradient(Val(:FiniteDiff), xy -> evaluate(KernelFunctions.DotProduct(), xy[1], xy[2]), [x,y])
17-
gfdelta = gradient(Val(:FiniteDiff), xy -> evaluate(KernelFunctions.Delta(), xy[1], xy[2]), [x,y])
18-
gfsinus = gradient(Val(:FiniteDiff), xy -> evaluate(KernelFunctions.Sinus(r), xy[1], xy[2]), [x,y])
24+
gfeucl = gradient(:FiniteDiff, [x,y]) do xy
25+
evaluate(Euclidean(), xy[1], xy[2])
26+
end
27+
gfsqeucl = gradient(:FiniteDiff, [x,y]) do xy
28+
evaluate(SqEuclidean(), xy[1], xy[2])
29+
end
30+
gfdotprod = gradient(:FiniteDiff, [x,y]) do xy
31+
evaluate(KernelFunctions.DotProduct(), xy[1], xy[2])
32+
end
33+
gfdelta = gradient(:FiniteDiff, [x,y]) do xy
34+
evaluate(KernelFunctions.Delta(), xy[1], xy[2])
35+
end
36+
gfsinus = gradient(:FiniteDiff, [x,y]) do xy
37+
evaluate(KernelFunctions.Sinus(r), xy[1], xy[2])
38+
end
1939

2040

2141
@test all(gzeucl .≈ gfeucl)

0 commit comments

Comments
 (0)