Skip to content

Commit 4c14e26

Browse files
committed
Removed opt_params and moved functions from squeeze.jl to different files
1 parent c0d2e74 commit 4c14e26

12 files changed

+44
-35
lines changed

src/KernelFunctions.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module KernelFunctions
22

3-
export kernelmatrix, kernelmatrix!, kerneldiagmatrix, kerneldiagmatrix!, kappa, kernelpdmat
4-
export get_params, set_params!
3+
export kernelmatrix, kernelmatrix!, kerneldiagmatrix, kerneldiagmatrix!, kappa, kernelpdmat # Main matrix functions
4+
export params, duplicate, set! # Helpers
55

66
export Kernel
77
export ConstantKernel, WhiteKernel, ZeroKernel
@@ -12,8 +12,7 @@ export LinearKernel, PolynomialKernel
1212
export RationalQuadraticKernel, GammaRationalQuadraticKernel
1313
export KernelSum, KernelProduct
1414

15-
export SelectTransform, ChainTransform, ScaleTransform, LowRankTransform, IdentityTransform, FunctionTransform
16-
15+
export Transform, SelectTransform, ChainTransform, ScaleTransform, LowRankTransform, IdentityTransform, FunctionTransform
1716

1817
using Distances, LinearAlgebra
1918
using SpecialFunctions: lgamma, besselk
@@ -42,6 +41,5 @@ include("kernels/kernelsum.jl")
4241
include("kernels/kernelproduct.jl")
4342

4443
include("generic.jl")
45-
include("squeeze.jl")
4644

4745
end

src/kernels/kernelproduct.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ end
2020

2121
params(k::KernelProduct) = params.(k.kernels)
2222
opt_params(k::KernelProduct) = opt_params.(k.kernels)
23+
duplicate(k::KernelProduct,θ) = KernelProduct(duplicate.(k.kernels,θ))
2324

2425
Base.:*(k1::Kernel,k2::Kernel) = KernelProduct([k1,k2])
2526
Base.:*(k1::KernelProduct,k2::KernelProduct) = KernelProduct(vcat(k1.kernels,k2.kernels)) #TODO Add test

src/kernels/kernelsum.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ end
2727

2828
params(k::KernelSum) = (k.weights,params.(k.kernels))
2929
opt_params(k::KernelSum) = (k.weights,opt_params.(k.kernels))
30+
duplicate(k::KernelSum,θ) = KernelSum(duplicate.(k.kernels,θ[end]),weights=first(θ))
3031

3132
Base.:+(k1::Kernel,k2::Kernel) = KernelSum([k1,k2],weights=[1.0,1.0])
3233
Base.:+(k1::KernelSum,k2::KernelSum) = KernelSum(vcat(k1.kernels,k2.kernels),weights=vcat(k1.weights,k2.weights))

src/squeeze.jl

Lines changed: 0 additions & 24 deletions
This file was deleted.

src/transform/ardtransform.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ function set!(t::ARDTransform{T},ρ::AbstractVector{T}) where {T<:Real}
2626
end
2727

2828
params(t::ARDTransform) = t.v
29-
opt_params(t::ARDTransform) = t.v
3029
dim(t::ARDTransform) = length(t.v)
3130

3231
function transform(t::ARDTransform,X::AbstractMatrix{<:Real},obsdim::Int)

src/transform/chaintransform.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ end
3232

3333
set!(t::ChainTransform,θ) = set!.(t.transforms,θ)
3434
params(t::ChainTransform) = (params.(t.transforms))
35-
opt_params(t::ChainTransform) = (opt_params.(t.transforms))
35+
duplicate(t::ChainTransform,θ) = ChainTransform(duplicate.(t.transforms,θ))
36+
3637

3738
Base.:(t₁::Transform,t₂::Transform) = ChainTransform([t₂,t₁])
3839
Base.:(t::Transform,tc::ChainTransform) = ChainTransform(vcat(tc.transforms,t)) #TODO add test

src/transform/functiontransform.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@ transform(t::FunctionTransform,X::T,obsdim::Int=defaultobs) where {T} = mapslice
1616

1717
params(t::FunctionTransform) = t.f
1818
opt_params(t::FunctionTransform) = nothing
19+
duplicate(t::FunctionTransform,θ) = t

src/transform/lowranktransform.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ function set!(t::LowRankTransform{<:AbstractMatrix{T}},M::AbstractMatrix{T}) whe
1717
end
1818

1919
params(t::LowRankTransform) = t.proj
20-
opt_params(t::LowRankTransform) = params(t)
2120

2221
Base.size(tr::LowRankTransform,i::Int) = size(tr.proj,i)
2322
Base.size(tr::LowRankTransform) = size(tr.proj) # TODO Add test

src/transform/scaletransform.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@ function ScaleTransform(s::T=1.0) where {T<:Real}
1616
end
1717

1818
set!(t::ScaleTransform::Real) = t.s .= [ρ]
19-
2019
params(t::ScaleTransform) = first(t.s)
21-
opt_params(t::ScaleTransform) = first(t.s)
2220
dim(str::ScaleTransform) = 1
2321

2422
transform(t::ScaleTransform,x::AbstractVecOrMat,obsdim::Int=defaultobs) = first(t.s) * x

src/transform/selecttransform.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ end
2323
set!(t::SelectTransform{<:AbstractVector{T}},dims::AbstractVector{T}) where {T<:Int} = t.select .= dims
2424

2525
params(t::SelectTransform) = t.select
26-
opt_params(t::SelectTransform) = nothing
26+
27+
duplicate(t::SelectTransform,θ) = t
28+
2729

2830
Base.maximum(t::SelectTransform) = maximum(t.select)
2931

0 commit comments

Comments
 (0)