Skip to content

Commit 07631b6

Browse files
committed
Tests passing for constant kernels/modified Zygote to return zeros instead of nothing
1 parent 44ad0cd commit 07631b6

File tree

5 files changed

+108
-69
lines changed

5 files changed

+108
-69
lines changed

src/distances/delta.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
struct Delta <: Distances.PreMetric
22
end
33

4-
@inline function Distances._evaluate(::Delta,a::AbstractVector{T},b::AbstractVector{T}) where {T}
4+
@inline function Distances._evaluate(::Delta, a::AbstractVector, b::AbstractVector) where {T}
55
@boundscheck if length(a) != length(b)
66
throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
77
end
88
return a == b
99
end
1010

11+
Distances.result_type(::Delta, Ta::Type, Tb::Type) = promote_type(Ta, Tb)
12+
1113
@inline (dist::Delta)(a::AbstractArray, b::AbstractArray) = Distances._evaluate(dist, a, b)
12-
@inline (dist::Delta)(a::Number,b::Number) = a == b
14+
@inline (dist::Delta)(a::Number, b::Number) = a == b

src/zygote_adjoints.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,29 @@
1+
## Adjoints Delta
2+
@adjoint function evaluate(s::Delta, x::AbstractVector, y::AbstractVector)
3+
evaluate(s, x, y), Δ -> begin
4+
(nothing, nothing, nothing)
5+
end
6+
end
7+
8+
@adjoint function pairwise(d::Delta, X::AbstractMatrix, Y::AbstractMatrix; dims=2)
9+
D = pairwise(d, X, Y; dims = dims)
10+
if dims == 1
11+
return D, Δ -> (nothing, nothing, nothing)
12+
else
13+
return D, Δ -> (nothing, nothing, nothing)
14+
end
15+
end
16+
17+
@adjoint function pairwise(d::Delta, X::AbstractMatrix; dims=2)
18+
D = pairwise(d, X; dims = dims)
19+
if dims == 1
20+
return D, Δ -> (nothing, nothing)
21+
else
22+
return D, Δ -> (nothing, nothing)
23+
end
24+
end
25+
26+
## Adjoints DotProduct
127
@adjoint function evaluate(s::DotProduct, x::AbstractVector, y::AbstractVector)
228
dot(x, y), Δ -> begin
329
(nothing, Δ .* y, Δ .* x)
@@ -22,6 +48,7 @@ end
2248
end
2349
end
2450

51+
## Adjoints Sinus
2552
@adjoint function evaluate(s::Sinus, x::AbstractVector, y::AbstractVector)
2653
d = (x - y)
2754
sind = sinpi.(d)

test/basekernels/constant.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
@test kappa(k,2.0) == 0.0
66
@test KernelFunctions.metric(ZeroKernel()) == KernelFunctions.Delta()
77
@test repr(k) == "Zero Kernel"
8+
test_AD("Zero", ZeroKernel)
89
end
910
@testset "WhiteKernel" begin
1011
k = WhiteKernel()
@@ -14,6 +15,7 @@
1415
@test EyeKernel == WhiteKernel
1516
@test metric(WhiteKernel()) == KernelFunctions.Delta()
1617
@test repr(k) == "White Kernel"
18+
test_AD("WhiteKernel", WhiteKernel)
1719
end
1820
@testset "ConstantKernel" begin
1921
c = 2.0
@@ -24,5 +26,6 @@
2426
@test metric(ConstantKernel()) == KernelFunctions.Delta()
2527
@test metric(ConstantKernel(c=2.0)) == KernelFunctions.Delta()
2628
@test repr(k) == "Constant Kernel (c = $(c))"
29+
test_AD("ConstantKernel", c->ConstantKernel(c=first(c)), [c])
2730
end
2831
end

test/runtests.jl

Lines changed: 56 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
using KernelFunctions
22
using Distances
3-
using FiniteDifferences
4-
using Flux
53
using Kronecker
64
using LinearAlgebra
75
using PDMats
86
using Random
97
using SpecialFunctions
108
using Test
11-
using Zygote
9+
import Zygote, ForwardDiff, ReverseDiff, FiniteDifferences
1210

1311
using KernelFunctions: metric, kappa
1412

@@ -45,66 +43,66 @@ using KernelFunctions: metric, kappa
4543
@testset "KernelFunctions" begin
4644

4745
include("utils.jl")
48-
49-
@testset "distances" begin
50-
include(joinpath("distances", "dotproduct.jl"))
51-
include(joinpath("distances", "delta.jl"))
52-
include(joinpath("distances", "sinus.jl"))
53-
end
54-
55-
@testset "transform" begin
56-
include(joinpath("transform", "transform.jl"))
57-
include(joinpath("transform", "scaletransform.jl"))
58-
include(joinpath("transform", "ardtransform.jl"))
59-
include(joinpath("transform", "lineartransform.jl"))
60-
include(joinpath("transform", "functiontransform.jl"))
61-
include(joinpath("transform", "selecttransform.jl"))
62-
include(joinpath("transform", "chaintransform.jl"))
63-
end
46+
include("utils_AD.jl")
47+
# @testset "distances" begin
48+
# include(joinpath("distances", "dotproduct.jl"))
49+
# include(joinpath("distances", "delta.jl"))
50+
# include(joinpath("distances", "sinus.jl"))
51+
# end
52+
#
53+
# @testset "transform" begin
54+
# include(joinpath("transform", "transform.jl"))
55+
# include(joinpath("transform", "scaletransform.jl"))
56+
# include(joinpath("transform", "ardtransform.jl"))
57+
# include(joinpath("transform", "lineartransform.jl"))
58+
# include(joinpath("transform", "functiontransform.jl"))
59+
# include(joinpath("transform", "selecttransform.jl"))
60+
# include(joinpath("transform", "chaintransform.jl"))
61+
# end
6462

6563
@testset "basekernels" begin
6664
include(joinpath("basekernels", "constant.jl"))
67-
include(joinpath("basekernels", "cosine.jl"))
68-
include(joinpath("basekernels", "exponential.jl"))
69-
include(joinpath("basekernels", "exponentiated.jl"))
70-
include(joinpath("basekernels", "fbm.jl"))
71-
include(joinpath("basekernels", "gabor.jl"))
72-
include(joinpath("basekernels", "maha.jl"))
73-
include(joinpath("basekernels", "matern.jl"))
74-
include(joinpath("basekernels", "nn.jl"))
75-
include(joinpath("basekernels", "periodic.jl"))
76-
include(joinpath("basekernels", "polynomial.jl"))
77-
include(joinpath("basekernels", "piecewisepolynomial.jl"))
78-
include(joinpath("basekernels", "rationalquad.jl"))
79-
include(joinpath("basekernels", "sm.jl"))
80-
include(joinpath("basekernels", "wiener.jl"))
81-
end
82-
83-
@testset "kernels" begin
84-
include(joinpath("kernels", "kernelproduct.jl"))
85-
include(joinpath("kernels", "kernelsum.jl"))
86-
include(joinpath("kernels", "scaledkernel.jl"))
87-
include(joinpath("kernels", "tensorproduct.jl"))
88-
include(joinpath("kernels", "transformedkernel.jl"))
89-
90-
# Legacy tests that don't correspond to anything meaningful in src. Unclear how
91-
# helpful these are.
92-
include(joinpath("kernels", "custom.jl"))
93-
end
94-
95-
@testset "matrix" begin
96-
include(joinpath("matrix", "kernelmatrix.jl"))
97-
include(joinpath("matrix", "kernelkroneckermat.jl"))
98-
include(joinpath("matrix", "kernelpdmat.jl"))
99-
end
100-
101-
@testset "approximations" begin
102-
include(joinpath("approximations", "nystrom.jl"))
65+
# include(joinpath("basekernels", "cosine.jl"))
66+
# include(joinpath("basekernels", "exponential.jl"))
67+
# include(joinpath("basekernels", "exponentiated.jl"))
68+
# include(joinpath("basekernels", "fbm.jl"))
69+
# include(joinpath("basekernels", "gabor.jl"))
70+
# include(joinpath("basekernels", "maha.jl"))
71+
# include(joinpath("basekernels", "matern.jl"))
72+
# include(joinpath("basekernels", "nn.jl"))
73+
# include(joinpath("basekernels", "periodic.jl"))
74+
# include(joinpath("basekernels", "polynomial.jl"))
75+
# include(joinpath("basekernels", "piecewisepolynomial.jl"))
76+
# include(joinpath("basekernels", "rationalquad.jl"))
77+
# include(joinpath("basekernels", "sm.jl"))
78+
# include(joinpath("basekernels", "wiener.jl"))
10379
end
10480

105-
include("generic.jl")
106-
include("zygote_adjoints.jl")
107-
include("trainable.jl")
81+
# @testset "kernels" begin
82+
# include(joinpath("kernels", "kernelproduct.jl"))
83+
# include(joinpath("kernels", "kernelsum.jl"))
84+
# include(joinpath("kernels", "scaledkernel.jl"))
85+
# include(joinpath("kernels", "tensorproduct.jl"))
86+
# include(joinpath("kernels", "transformedkernel.jl"))
87+
#
88+
# # Legacy tests that don't correspond to anything meaningful in src. Unclear how
89+
# # helpful these are.
90+
# include(joinpath("kernels", "custom.jl"))
91+
# end
92+
#
93+
# @testset "matrix" begin
94+
# include(joinpath("matrix", "kernelmatrix.jl"))
95+
# include(joinpath("matrix", "kernelkroneckermat.jl"))
96+
# include(joinpath("matrix", "kernelpdmat.jl"))
97+
# end
98+
#
99+
# @testset "approximations" begin
100+
# include(joinpath("approximations", "nystrom.jl"))
101+
# end
102+
#
103+
# include("generic.jl")
104+
# include("zygote_adjoints.jl")
105+
# include("trainable.jl")
108106
end
109107

110108
# These are legacy tests that I'm not getting rid of, as they appear to be useful, but

test/utils_AD.jl

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1+
12
FDM = FiniteDifferences.central_fdm(5, 1)
23

34
function gradient(::Val{:Zygote}, f::Function, args)
4-
first(Zygote.gradient(f, args))
5-
end
6-
7-
function gradient(::Val{:Zygote}, f::Function, args::Zygote.Params)
8-
Zygote.gradient(f, args)
5+
g = first(Zygote.gradient(f, args))
6+
if isnothing(g)
7+
return zeros(size(args)) # To respect the same output as other ADs
8+
else
9+
return g
10+
end
911
end
1012

1113
function gradient(::Val{:ForwardDiff}, f::Function, args)
@@ -24,14 +26,22 @@ end
2426
testfunction(k, A, B, dim) = sum(kernelmatrix(k, A, B, obsdim = dim))
2527
testfunction(k, A, dim) = sum(kernelmatrix(k, A, obsdim = dim))
2628

27-
function test_FiniteDiff(kernelname, kernelfunction, args = nothing)
29+
function test_AD(kernelname::String, kernelfunction, args = nothing; ADs = [:Zygote, :ForwardDiff, :ReverseDiff], dims = [3, 3])
30+
test_fd = test_FiniteDiff(kernelname, kernelfunction, args, dims)
31+
if !test_fd.anynonpass
32+
for AD in ADs
33+
test_AD(AD, kernelname, kernelfunction, args, dims)
34+
end
35+
end
36+
end
37+
38+
function test_FiniteDiff(kernelname, kernelfunction, args = nothing, dims = [3, 3])
2839
# Init arguments :
2940
k = if args === nothing
3041
kernelfunction()
3142
else
3243
kernelfunction(args)
3344
end
34-
dims = [3, 3]
3545
rng = MersenneTwister(42)
3646
@testset "FiniteDifferences with $(kernelname)" begin
3747
if k isa SimpleKernel
@@ -60,10 +70,9 @@ function test_FiniteDiff(kernelname, kernelfunction, args = nothing)
6070
end
6171
end
6272

63-
function test_AD(AD, kernelname, kernelfunction, args = nothing)
73+
function test_AD(AD, kernelname, kernelfunction, args = nothing, dims = [3, 3])
6474
@testset "Testing $(kernelname) with AD : $(AD)" begin
6575
# Test kappa function
66-
dims = [3, 3]
6776
k = if args === nothing
6877
kernelfunction()
6978
else

0 commit comments

Comments
 (0)