Skip to content

Commit 395a629

Browse files
authored
Merge pull request #121 from devmotion/maternad
2 parents b2e2471 + 6a9ebea commit 395a629

File tree

4 files changed

+5
-20
lines changed

4 files changed

+5
-20
lines changed

src/KernelFunctions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ export spectral_mixture_kernel, spectral_mixture_product_kernel
3434
using Compat
3535
using Requires
3636
using Distances, LinearAlgebra
37-
using SpecialFunctions: logabsgamma, besselk, polygamma
37+
using SpecialFunctions: loggamma, besselk, polygamma
3838
using ZygoteRules: @adjoint, pullback
3939
using StatsFuns: logtwo
4040
using InteractiveUtils: subtypes

src/basekernels/matern.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,13 @@ struct MaternKernel{Tν<:Real} <: SimpleKernel
1616
end
1717

1818
@inline function kappa::MaternKernel, d::Real)
19-
ν = first.ν)
20-
iszero(d) ? one(d) : _matern(ν, d)
19+
result = _matern(first.ν), d)
20+
return ifelse(iszero(d), one(result), result)
2121
end
2222

2323
function _matern::Real, d::Real)
24-
exp((one(d) - ν) * logtwo - loggamma(ν) + ν * log(sqrt(2ν) * d) + log(besselk(ν, sqrt(2ν) * d)))
24+
y = sqrt(2ν) * d
25+
return exp((one(d) - ν) * logtwo - loggamma(ν) + ν * log(y) + log(besselk(ν, y)))
2526
end
2627

2728
metric(::MaternKernel) = Euclidean()

src/utils.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
hadamard(x, y) = x .* y
22

3-
loggamma(x) = first(logabsgamma(x))
4-
53
# Macro for checking arguments
64
macro check_args(K, param, cond, desc=string(cond))
75
quote

src/zygote_adjoints.jl

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -59,20 +59,6 @@ end
5959
end
6060
end
6161

62-
@adjoint function loggamma(x)
63-
first(logabsgamma(x)) , Δ ->.* polygamma(0, x), )
64-
end
65-
66-
@adjoint function kappa::MaternKernel, d::Real)
67-
ν = first.ν)
68-
val, grad = pullback(_matern, ν, d)
69-
return ((iszero(d) ? one(d) : val),
70-
Δ -> begin
71-
= grad(Δ)
72-
return ((ν = [∇[1]],), iszero(d) ? zero(d) : ∇[2])
73-
end)
74-
end
75-
7662
@adjoint function ColVecs(X::AbstractMatrix)
7763
back::NamedTuple) =.X,)
7864
back::AbstractMatrix) = (Δ,)

0 commit comments

Comments
 (0)