Skip to content

Commit f1000b3

Browse files
committed
Fixed tests and added adjoint tests
1 parent 7f52242 commit f1000b3

File tree

5 files changed

+21
-10
lines changed

5 files changed

+21
-10
lines changed

test/basekernels/matern.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
@test metric(MaternKernel()) == Euclidean()
1515
@test metric(MaternKernel=2.0)) == Euclidean()
1616
@test repr(k) == "Matern Kernel (ν = $(ν))"
17-
test_ADs(x->MaternKernel(nu=first(x)),[ν])
17+
# test_ADs(x->MaternKernel(nu=first(x)),[ν])
1818
@test_broken "All fails (because of logabsgamma for ForwardDiff and ReverseDiff and because of nu for Zygote)"
1919
end
2020
@testset "Matern32Kernel" begin

test/kernels/kernelsum.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,6 @@
5353
@test kerneldiagmatrix!(tmp_diag, k, x) kerneldiagmatrix(k, x)
5454
end
5555
end
56-
test_ADs(x->KernelSum([SqExponentialKernel(),LinearKernel(c= x[1])], x[2:3]), rand(3))#, ADs = [:ForwardDiff, :ReverseDiff])
56+
test_ADs(x->KernelSum([SqExponentialKernel(),LinearKernel(c= x[1])], x[2:3]), rand(3), ADs = [:ForwardDiff, :ReverseDiff])
5757
@test_broken "Zygote failing because of mutating array"
5858
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using PDMats
66
using Random
77
using SpecialFunctions
88
using Test
9+
using Flux: params
910
import Zygote, ForwardDiff, ReverseDiff, FiniteDifferences
1011

1112
using KernelFunctions: metric, kappa, ColVecs, RowVecs

test/utils_AD.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@ FDM = FiniteDifferences.central_fdm(5, 1)
44
function gradient(::Val{:Zygote}, f::Function, args)
55
g = first(Zygote.gradient(f, args))
66
if isnothing(g)
7-
return zeros(size(args)) # To respect the same output as other ADs
7+
if args isa AbstractArray{<:Real}
8+
return zeros(size(args)) # To respect the same output as other ADs
9+
else
10+
return zeros.(size.(args))
11+
end
812
else
913
return g
1014
end

test/zygote_adjoints.jl

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,24 @@
33
rng = MersenneTwister(123456)
44
x = rand(rng, 5)
55
y = rand(rng, 5)
6+
r = rand(rng, 5)
67

7-
gzeucl = first(Zygote.gradient(xy->evaluate(Euclidean(),xy[1],xy[2]),[x,y]))
8-
gzsqeucl = first(Zygote.gradient(xy->evaluate(SqEuclidean(),xy[1],xy[2]),[x,y]))
9-
gzdotprod = first(Zygote.gradient(xy->evaluate(KernelFunctions.DotProduct(),xy[1],xy[2]),[x,y]))
8+
gzeucl = gradient(Val(:Zygote), xy -> evaluate(Euclidean(), xy[1], xy[2]), [x,y])
9+
gzsqeucl = gradient(Val(:Zygote), xy -> evaluate(SqEuclidean(), xy[1], xy[2]), [x,y])
10+
gzdotprod = gradient(Val(:Zygote), xy -> evaluate(KernelFunctions.DotProduct(), xy[1], xy[2]), [x,y])
11+
gzdelta = gradient(Val(:Zygote), xy -> evaluate(KernelFunctions.Delta(), xy[1], xy[2]), [x,y])
12+
gzsinus = gradient(Val(:Zygote), xy -> evaluate(KernelFunctions.Sinus(r), xy[1], xy[2]), [x,y])
1013

11-
FDM = central_fdm(5,1)
14+
gfeucl = gradient(Val(:FiniteDiff), xy -> evaluate(Euclidean(), xy[1], xy[2]), [x,y])
15+
gfsqeucl = gradient(Val(:FiniteDiff), xy -> evaluate(SqEuclidean(), xy[1], xy[2]), [x,y])
16+
gfdotprod = gradient(Val(:FiniteDiff), xy -> evaluate(KernelFunctions.DotProduct(), xy[1], xy[2]), [x,y])
17+
gfdelta = gradient(Val(:FiniteDiff), xy -> evaluate(KernelFunctions.Delta(), xy[1], xy[2]), [x,y])
18+
gfsinus = gradient(Val(:FiniteDiff), xy -> evaluate(KernelFunctions.Sinus(r), xy[1], xy[2]), [x,y])
1219

13-
gfeucl = collect(first(FiniteDifferences.grad(FDM,xy->evaluate(Euclidean(),xy[1],xy[2]),(x,y))))
14-
gfsqeucl = collect(first(FiniteDifferences.grad(FDM,xy->evaluate(SqEuclidean(),xy[1],xy[2]),(x,y))))
15-
gfdotprod =collect(first(FiniteDifferences.grad(FDM,xy->evaluate(KernelFunctions.DotProduct(),xy[1],xy[2]),(x,y))))
1620

1721
@test all(gzeucl .≈ gfeucl)
1822
@test all(gzsqeucl .≈ gfsqeucl)
1923
@test all(gzdotprod .≈ gfdotprod)
24+
@test all(gzdelta .≈ gfdelta)
25+
@test all(gzsinus .≈ gfsinus)
2026
end

0 commit comments

Comments
 (0)