Skip to content

Commit 2234d4f

Browse files
committed
Solved tests for Matern
1 parent df3819a commit 2234d4f

File tree

4 files changed

+23
-9
lines changed

4 files changed

+23
-9
lines changed

src/KernelFunctions.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ export spectral_mixture_kernel, spectral_mixture_product_kernel
3434
using Compat
3535
using Requires
3636
using Distances, LinearAlgebra
37-
using SpecialFunctions: logabsgamma, besselk
38-
using ZygoteRules: @adjoint
37+
using SpecialFunctions: logabsgamma, besselk, polygamma
38+
using ZygoteRules: @adjoint, pullback
3939
using StatsFuns: logtwo
4040
using InteractiveUtils: subtypes
4141
using StatsBase

src/basekernels/matern.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,11 @@ end
1717

1818
@inline function kappa::MaternKernel, d::Real)
1919
ν = first.ν)
20-
iszero(d) ? one(d) :
21-
exp(
22-
(one(d) - ν) * logtwo - logabsgamma(ν)[1] +
23-
ν * log(sqrt(2ν) * d) +
24-
log(besselk(ν, sqrt(2ν) * d))
25-
)
20+
iszero(d) ? one(d) : _matern(ν, d)
21+
end
22+
23+
function _matern::Real, d::Real)
24+
exp((one(d) - ν) * logtwo - loggamma(ν) + ν * log(sqrt(2ν) * d) + log(besselk(ν, sqrt(2ν) * d)))
2625
end
2726

2827
metric(::MaternKernel) = Euclidean()

src/utils.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
hadamard(x, y) = x .* y
22

3+
loggamma(x) = first(logabsgamma(x))
4+
35
# Macro for checking arguments
46
macro check_args(K, param, cond, desc=string(cond))
57
quote
@@ -124,4 +126,3 @@ function validate_dims(x::AbstractVector, y::AbstractVector)
124126
))
125127
end
126128
end
127-

src/zygote_adjoints.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,20 @@
44
end
55
end
66

7+
@adjoint function loggamma(x)
8+
first(logabsgamma(x)) , Δ ->.* polygamma(0, x), )
9+
end
10+
11+
@adjoint function kappa::MaternKernel, d::Real)
12+
ν = first.ν)
13+
val, grad = pullback(_matern, ν, d)
14+
return ((iszero(d) ? one(d) : val),
15+
Δ -> begin
16+
= grad(Δ)
17+
return ((ν = [∇[1]],), iszero(d) ? zero(d) : ∇[2])
18+
end)
19+
end
20+
721
@adjoint function ColVecs(X::AbstractMatrix)
822
back::NamedTuple) =.X,)
923
back::AbstractMatrix) = (Δ,)

0 commit comments

Comments
 (0)