Skip to content

Commit f9d9bd0

Browse files
committed
Moved trainable to trainable.jl and added Flux via Requires
1 parent 2f1229d commit f9d9bd0

19 files changed

+50
-47
lines changed

src/KernelFunctions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ include("zygote_adjoints.jl")
5858
function __init__()
5959
@require Kronecker="2c470bb0-bcc8-11e8-3dad-c9649493f05e" include("matrix/kernelkroneckermat.jl")
6060
@require PDMats="90014a1f-27ba-587c-ab20-58faa44d9150" include("matrix/kernelpdmat.jl")
61+
@require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" include("trainable.jl")
6162
end
6263

6364
end

src/generic.jl

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,7 @@ _scale(t::ScaleTransform, metric, x, y) = evaluate(metric, apply(t, x), apply(t,
1414
printshifted(io::IO::Kernel,shift::Int) = print(io,"")
1515
Base.show(io::IO::Kernel) = print(io,nameof(typeof(κ)))
1616

17-
function params(k::Kernel)
18-
ps = []
19-
params!(ps,k)
20-
return ps
21-
end
22-
23-
function params!(ps,k::Kernel)
24-
for child in trainable(k)
25-
params!(ps,k)
26-
end
27-
end
28-
29-
params!(ps,x::AbstractArray) = push!(ps,x)
30-
31-
trainable(x) = ()
17+
_trainable(x) = ()
3218

3319
### Syntactic sugar for creating matrices and using kernel functions
3420
for k in subtypes(BaseKernel)

src/kernels/constant.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,6 @@ struct ConstantKernel{Tc<:Real} <: BaseKernel
4141
end
4242
end
4343

44-
trainable(k::ConstantKernel) = (k.c,)
45-
4644
kappa::ConstantKernel,x::Real) = first.c)*one(x)
4745

4846
metric(::ConstantKernel) = Delta()

src/kernels/exponential.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,6 @@ struct GammaExponentialKernel{Tγ<:Real} <: BaseKernel
5454
end
5555
end
5656

57-
trainable(k::GammaExponentialKernel) = (γ,)
58-
5957
kappa::GammaExponentialKernel, d²::Real) = exp(-^first.γ))
6058
iskroncompatible(::GammaExponentialKernel) = true
6159
metric(::GammaExponentialKernel) = SqEuclidean()

src/kernels/kernelproduct.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,6 @@ struct KernelProduct <: Kernel
1414
kernels::Vector{Kernel}
1515
end
1616

17-
params(k::KernelProduct) = params.(k.kernels)
18-
opt_params(k::KernelProduct) = opt_params.(k.kernels)
19-
2017
Base.:*(k1::Kernel,k2::Kernel) = KernelProduct([k1,k2])
2118
Base.:*(k1::KernelProduct,k2::KernelProduct) = KernelProduct(vcat(k1.kernels,k2.kernels)) #TODO Add test
2219
Base.:*(k::Kernel,kp::KernelProduct) = KernelProduct(vcat(k,kp.kernels))

src/kernels/kernelsum.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,6 @@ function KernelSum(
2525
KernelSum(kernels, weights)
2626
end
2727

28-
params(k::KernelSum) = (k.weights, params.(k.kernels))
29-
opt_params(k::KernelSum) = (k.weights, opt_params.(k.kernels))
30-
3128
Base.:+(k1::Kernel, k2::Kernel) = KernelSum([k1, k2], weights = [1.0, 1.0])
3229
Base.:+(k1::ScaledKernel, k2::ScaledKernel) = KernelSum([kernel(k1), kernel(k2)], weights = [first(k1.σ), first(k2.σ)])
3330
Base.:+(k1::KernelSum, k2::KernelSum) =

src/kernels/matern.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@ struct MaternKernel{Tν<:Real} <: BaseKernel
1414
end
1515
end
1616

17-
trainable(k::MaternKernel) = (k.ν,)
18-
1917
@inline function kappa::MaternKernel, d::Real)
2018
ν = first.ν)
2119
iszero(d) ? one(d) : exp((one(d)-ν)*logtwo-logabsgamma(ν)[1] + ν*log(sqrt(2ν)*d)+log(besselk(ν,sqrt(2ν)*d)))

src/kernels/polynomial.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@ struct LinearKernel{Tc<:Real} <: BaseKernel
1313
end
1414
end
1515

16-
trainable(k::LinearKernel) = (k.c,)
17-
1816
kappa::LinearKernel, xᵀy::Real) = xᵀy + first.c)
1917

2018
metric(::LinearKernel) = DotProduct()
@@ -36,8 +34,6 @@ struct PolynomialKernel{Td<:Real, Tc<:Real} <: BaseKernel
3634
end
3735
end
3836

39-
trainable(k::PolynomialKernel) = (k.d,k.c)
40-
4137
kappa::PolynomialKernel, xᵀy::T) where {T<:Real} = (xᵀy + first.c))^(first.d))
4238

4339
metric(::PolynomialKernel) = DotProduct()

src/kernels/rationalquad.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@ struct RationalQuadraticKernel{Tα<:Real} <: BaseKernel
1414
end
1515
end
1616

17-
trainable(k::RationalQuadraticKernel) = (k.α,)
18-
1917
kappa::RationalQuadraticKernel, d²::T) where {T<:Real} = (one(T)+/first.α))^(-first.α))
2018

2119
metric(::RationalQuadraticKernel) = SqEuclidean()
@@ -38,8 +36,6 @@ struct GammaRationalQuadraticKernel{Tα<:Real, Tγ<:Real} <: BaseKernel
3836
end
3937
end
4038

41-
trainable(k::GammaRationalQuadraticKernel) = (k.α,k.γ)
42-
4339
kappa::GammaRationalQuadraticKernel, d²::T) where {T<:Real} = (one(T)+^first.γ)/first.α))^(-first.α))
4440

4541
metric(::GammaRationalQuadraticKernel) = SqEuclidean()

src/kernels/scaledkernel.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@ kappa(k::ScaledKernel, x) = first(k.σ)*kappa(k.kernel, x)
1212

1313
metric(k::ScaledKernel) = metric(k.kernel)
1414

15-
trainable(k::ScaledKernel) = (k.σ,k.kernel)
16-
1715
Base.:*(w::Real,k::Kernel) = ScaledKernel(k,w)
1816

1917
Base.show(io::IO::ScaledKernel) = printshifted(io,κ,0)

0 commit comments

Comments
 (0)