diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl index 5ffa58654..6844e7ae1 100644 --- a/src/KernelFunctions.jl +++ b/src/KernelFunctions.jl @@ -34,7 +34,7 @@ export spectral_mixture_kernel, spectral_mixture_product_kernel using Compat using Requires using Distances, LinearAlgebra -using SpecialFunctions: logabsgamma, besselk, polygamma +using SpecialFunctions: loggamma, besselk, polygamma using ZygoteRules: @adjoint, pullback using StatsFuns: logtwo using InteractiveUtils: subtypes diff --git a/src/basekernels/matern.jl b/src/basekernels/matern.jl index 44b5eb989..13fb455f6 100644 --- a/src/basekernels/matern.jl +++ b/src/basekernels/matern.jl @@ -16,12 +16,13 @@ struct MaternKernel{Tν<:Real} <: SimpleKernel end @inline function kappa(κ::MaternKernel, d::Real) - ν = first(κ.ν) - iszero(d) ? one(d) : _matern(ν, d) + result = _matern(first(κ.ν), d) + return ifelse(iszero(d), one(result), result) end function _matern(ν::Real, d::Real) - exp((one(d) - ν) * logtwo - loggamma(ν) + ν * log(sqrt(2ν) * d) + log(besselk(ν, sqrt(2ν) * d))) + y = sqrt(2ν) * d + return exp((one(d) - ν) * logtwo - loggamma(ν) + ν * log(y) + log(besselk(ν, y))) end metric(::MaternKernel) = Euclidean() diff --git a/src/utils.jl b/src/utils.jl index ed11f2428..a6239e88a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,7 +1,5 @@ hadamard(x, y) = x .* y -loggamma(x) = first(logabsgamma(x)) - # Macro for checking arguments macro check_args(K, param, cond, desc=string(cond)) quote diff --git a/src/zygote_adjoints.jl b/src/zygote_adjoints.jl index a95be8142..2b79750d5 100644 --- a/src/zygote_adjoints.jl +++ b/src/zygote_adjoints.jl @@ -59,20 +59,6 @@ end end end -@adjoint function loggamma(x) - first(logabsgamma(x)) , Δ -> (Δ .* polygamma(0, x), ) -end - -@adjoint function kappa(κ::MaternKernel, d::Real) - ν = first(κ.ν) - val, grad = pullback(_matern, ν, d) - return ((iszero(d) ? one(d) : val), - Δ -> begin - ∇ = grad(Δ) - return ((ν = [∇[1]],), iszero(d) ? zero(d) : ∇[2]) - end) -end - @adjoint function ColVecs(X::AbstractMatrix) back(Δ::NamedTuple) = (Δ.X,) back(Δ::AbstractMatrix) = (Δ,)