Skip to content

Commit 0598c86

Browse files
committed
Made corrections on the FBM kernel
1 parent da9a191 commit 0598c86

File tree

4 files changed

+41
-11
lines changed

4 files changed

+41
-11
lines changed

src/trainable.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import .Flux.trainable
44

55
trainable(k::ConstantKernel) = (k.c,)
66

7+
trainable(k::FBMKernel) = (k.h,)
8+
79
trainable(k::GammaExponentialKernel) = (k.γ,)
810

911
trainable(k::GammaRationalQuadraticKernel) = (k.α, k.γ)

test/kernels/fbm.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
@testset "FBM" begin
2+
h = 0.3
3+
k = FBMKernel(h = h)
4+
v1 = rand(3); v2 = rand(3)
5+
@test k(v1,v2) (sqeuclidean(v1, zero(v1))^h + sqeuclidean(v2, zero(v2))^h - sqeuclidean(v1-v2, zero(v1-v2))^h)/2 atol=1e-5
6+
7+
# kernelmatrix tests
8+
m1 = rand(3,3)
9+
m2 = rand(3,3)
10+
@test kernelmatrix(k, m1, m1) kernelmatrix(k, m1) atol=1e-5
11+
@test kernelmatrix(k, m1, m2) k(m1, m2) atol=1e-5
12+
13+
14+
x1 = rand()
15+
x2 = rand()
16+
@test kernelmatrix(k, x1*ones(1,1), x2*ones(1,1))[1] k(x1, x2) atol=1e-5
17+
end

test/runtests.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,17 +62,18 @@ using KernelFunctions: metric
6262
end
6363

6464
@testset "kernels" begin
65+
include(joinpath("kernels", "constant.jl"))
66+
include(joinpath("kernels", "cosine.jl"))
6567
include(joinpath("kernels", "exponential.jl"))
68+
include(joinpath("kernels", "exponentiated.jl"))
69+
include(joinpath("kernels", "fbm.jl"))
70+
include(joinpath("kernels", "kernelproduct.jl"))
71+
include(joinpath("kernels", "kernelsum.jl"))
6672
include(joinpath("kernels", "matern.jl"))
6773
include(joinpath("kernels", "polynomial.jl"))
68-
include(joinpath("kernels", "constant.jl"))
6974
include(joinpath("kernels", "rationalquad.jl"))
70-
include(joinpath("kernels", "exponentiated.jl"))
71-
include(joinpath("kernels", "cosine.jl"))
72-
include(joinpath("kernels", "transformedkernel.jl"))
7375
include(joinpath("kernels", "scaledkernel.jl"))
74-
include(joinpath("kernels", "kernelsum.jl"))
75-
include(joinpath("kernels", "kernelproduct.jl"))
76+
include(joinpath("kernels", "transformedkernel.jl"))
7677

7778
# Legacy tests that don't correspond to anything meaningful in src. Unclear how
7879
# helpful these are.

test/trainable.jl

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,27 @@
11
@testset "trainable" begin
2-
ν = 2.0; c = 3.0; d = 2.0; γ = 2.0; α = 2.5
2+
ν = 2.0; c = 3.0; d = 2.0; γ = 2.0; α = 2.5; h = 0.5
3+
34
kc = ConstantKernel(c=c)
45
@test all(params(kc) .== params([c]))
5-
km = MaternKernel=ν)
6-
@test all(params(km) .== params([ν]))
7-
kl = LinearKernel(c=c)
8-
@test all(params(kl) .== params([c]))
6+
7+
kfbm = FBMKernel(h = h)
8+
@test all(params(kfbm) .== params([h]))
9+
910
kge = GammaExponentialKernel=γ)
1011
@test all(params(kge) .== params([γ]))
12+
1113
kgr = GammaRationalQuadraticKernel=γ, α=α)
1214
@test all(params(kgr) .== params([α], [γ]))
15+
16+
kl = LinearKernel(c=c)
17+
@test all(params(kl) .== params([c]))
18+
19+
km = MaternKernel=ν)
20+
@test all(params(km) .== params([ν]))
21+
1322
kp = PolynomialKernel(c=c, d=d)
1423
@test all(params(kp) .== params([d], [c]))
24+
1525
kr = RationalQuadraticKernel=α)
1626
@test all(params(kr) .== params([α]))
1727

0 commit comments

Comments
 (0)