Skip to content

Commit fced7cb

Browse files
committed
Updated tests and correction constructor
1 parent 5b9c395 commit fced7cb

File tree

5 files changed

+15
-11
lines changed

5 files changed

+15
-11
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.2.1"
3+
version = "0.2.2"
44

55
[deps]
66
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"

src/KernelFunctions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ export KernelSum, KernelProduct
1515
export SelectTransform, ChainTransform, ScaleTransform, LowRankTransform, IdentityTransform, FunctionTransform
1616

1717

18-
using Distances, LinearAlgebra, StaticArrays
18+
using Distances, LinearAlgebra
1919
using SpecialFunctions: lgamma, besselk
2020
using StatsFuns: logtwo
2121
using PDMats: PDMat

src/kernels/exponential.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ end
7070

7171
function GammaExponentialKernel::AbstractVector{T₁},gamma::T₂=2.0) where {T₁<:Real,T₂<:Real}
7272
@check_args(GammaExponentialKernel, gamma, gamma >= zero(T₂), "gamma > 0")
73-
GammaExponentialKernel{T₁,ARDTransform{T₁,length(ρ)},T₂}(ScaleTransform(ρ),gamma)
73+
GammaExponentialKernel{T₁,ARDTransform{T₁,length(ρ)},T₂}(ARDTransform(ρ),gamma)
7474
end
7575

7676
function GammaExponentialKernel(t::Tr,gamma::T₁=2.0) where {Tr<:Transform,T₁<:Real}

test/test_constructors.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,23 +9,21 @@ s = ScaleTransform(l)
99
@testset "SqExponentialKernel" begin
1010
@test KernelFunctions.metric(SqExponentialKernel(l)) == SqEuclidean()
1111
@test isequal(transform(SqExponentialKernel(l)),s)
12-
@test KernelFunctions.transform(SqExponentialKernel(vl)) == ScaleTransform(vl)
12+
@test KernelFunctions.transform(SqExponentialKernel(vl)) == ARDTransform(vl)
1313
@test isequal(KernelFunctions.transform(SqExponentialKernel(s)),s)
1414
end
1515

1616
## MaternKernel
17-
ScaleTransform{Base.RefValue{Float64}}(Base.RefValue{Float64}(2.0))
18-
ScaleTransform{Base.RefValue{Float64}}(Base.RefValue{Float64}(2.0))
1917
@testset "MaternKernel" begin
2018
@test KernelFunctions.metric(MaternKernel(l)) == Euclidean()
2119
@test KernelFunctions.metric(Matern32Kernel(l)) == Euclidean()
2220
@test KernelFunctions.metric(Matern52Kernel(l)) == Euclidean()
2321
@test isequal(KernelFunctions.transform(MaternKernel(l)),s)
2422
@test isequal(KernelFunctions.transform(Matern32Kernel(l)),s)
2523
@test isequal(KernelFunctions.transform(Matern52Kernel(l)),s)
26-
@test KernelFunctions.transform(MaternKernel(vl)) == ScaleTransform(vl)
27-
@test KernelFunctions.transform(Matern32Kernel(vl)) == ScaleTransform(vl)
28-
@test KernelFunctions.transform(Matern52Kernel(vl)) == ScaleTransform(vl)
24+
@test KernelFunctions.transform(MaternKernel(vl)) == ARDTransform(vl)
25+
@test KernelFunctions.transform(Matern32Kernel(vl)) == ARDTransform(vl)
26+
@test KernelFunctions.transform(Matern52Kernel(vl)) == ARDTransform(vl)
2927
@test KernelFunctions.transform(MaternKernel(s)) == s
3028
@test KernelFunctions.transform(Matern32Kernel(s)) == s
3129
@test KernelFunctions.transform(Matern52Kernel(s)) == s

test/test_transform.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,15 @@ f(x) = sin.(x)
1818
## Test Scale Transform
1919
@testset "ScaleTransform" begin
2020
t = ScaleTransform(s)
21-
vt1 = ScaleTransform(v1)
22-
vt2 = ScaleTransform(v2)
2321
@test all(KernelFunctions.transform(t,X).==s*X)
22+
s2 = 2.0
23+
KernelFunctions.set!(t,s2)
24+
@test all(t.s.==[s2])
25+
end
26+
## Test ARD Transform
27+
@testset "ARDTransform" begin
28+
vt1 = ARDTransform(v1)
29+
vt2 = ARDTransform(v2)
2430
@test all(KernelFunctions.transform(vt1,X,1).==v1'.*X)
2531
@test all(KernelFunctions.transform(vt2,X,2).==v2.*X)
2632
end

0 commit comments

Comments
 (0)