Skip to content

Commit 3377137

Browse files
committed
Added test for checking parameters
1 parent 4c8c816 commit 3377137

File tree

3 files changed

+66
-21
lines changed

3 files changed

+66
-21
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,12 @@ julia = "1.0"
2525

2626
[extras]
2727
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
28+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
2829
Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e"
2930
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
3031
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
3132
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3233
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3334

3435
[targets]
35-
test = ["Random", "Test", "FiniteDifferences", "Zygote", "PDMats", "Kronecker"]
36+
test = ["Random", "Test", "FiniteDifferences", "Zygote", "PDMats", "Kronecker", "Flux"]

src/trainable.jl

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,39 @@
1-
using .Flux: trainable
2-
3-
Flux.trainable(::Kernel) = () # By default no parameters are returned
4-
Flux.trainable(::Transform) = ()
1+
import .Flux.trainable
52

63
### Base Kernels
74

8-
Flux.trainable(k::ConstantKernel) = (k.c,)
5+
trainable(k::ConstantKernel) = (k.c,)
96

10-
Flux.trainable(k::GammaExponentialKernel) = (γ,)
7+
trainable(k::GammaExponentialKernel) = (k.γ,)
118

12-
Flux.trainable(k::GammaRationalQuadraticKernel) = (k.α, k.γ)
9+
trainable(k::GammaRationalQuadraticKernel) = (k.α, k.γ)
1310

14-
Flux.trainable(k::MaternKernel) = (k.ν,)
11+
trainable(k::MaternKernel) = (k.ν,)
1512

16-
Flux.trainable(k::LinearKernel) = (k.c,)
13+
trainable(k::LinearKernel) = (k.c,)
1714

18-
Flux.trainable(k::PolynomialKernel) = (k.d, k.c)
15+
trainable(k::PolynomialKernel) = (k.d, k.c)
1916

20-
Flux.trainable(k::RationalQuadraticKernel) = (k.α,)
17+
trainable(k::RationalQuadraticKernel) = (k.α,)
2118

2219
#### Composite kernels
2320

24-
Flux.trainable::KernelProduct) = k.kernels
21+
trainable::KernelProduct) = κ.kernels
2522

26-
Flux.trainable::KernelSum) =.weights, κ.kernels) #To check
23+
trainable::KernelSum) =.weights, κ.kernels) #To check
2724

28-
Flux.trainable::ScaledKernel) =.σ, κ.kernel)
25+
trainable::ScaledKernel) =.σ, κ.kernel)
2926

30-
Flux.trainable::TransformedKernel) =.transform, κ.kernel)
27+
trainable::TransformedKernel) =.transform, κ.kernel)
3128

3229
### Transforms
3330

34-
Flux.trainable(t::ARDTransform) = (t.v,)
31+
trainable(t::ARDTransform) = (t.v,)
3532

36-
Flux.trainable(t::ChainTransform) = t.transforms
33+
trainable(t::ChainTransform) = t.transforms
3734

38-
Flux.trainable(t::FunctionTransform) = (t.f,)
35+
trainable(t::FunctionTransform) = (t.f,)
3936

40-
Flux.trainable(t::LowRankTransform) = (t.proj,)
37+
trainable(t::LowRankTransform) = (t.proj,)
4138

42-
Flux.trainable(t::ScaleTransform) = (t.s,)
39+
trainable(t::ScaleTransform) = (t.s,)

test/test_flux.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
using KernelFunctions
2+
using Test
3+
using Flux
4+
5+
@testset "Params" begin
6+
ν = 2.0; c = 3.0; d = 2.0; γ = 2.0; α = 2.5
7+
kc = ConstantKernel(c=c)
8+
@test all(params(kc) .== params([c]))
9+
km = MaternKernel=ν)
10+
@test all(params(km) .== params([ν]))
11+
kl = LinearKernel(c=c)
12+
@test all(params(kl) .== params([c]))
13+
kge = GammaExponentialKernel=γ)
14+
@test all(params(kge) .== params([γ]))
15+
kgr = GammaRationalQuadraticKernel=γ, α=α)
16+
@test all(params(kgr) .== params([α], [γ]))
17+
kp = PolynomialKernel(c=c, d=d)
18+
@test all(params(kp) .== params([d], [c]))
19+
kr = RationalQuadraticKernel=α)
20+
@test all(params(kr) .== params([α]))
21+
22+
k = km + kc
23+
@test all(params(k) .== params([k.weights], km, kc))
24+
25+
k = km * kc
26+
@test all(params(k) .== params(km, kc))
27+
28+
s = 2.0
29+
k = transform(km, s)
30+
@test all(params(k) .== params([s], km))
31+
32+
v = [2.0]
33+
k = transform(kc, v)
34+
@test all(params(k) .== params(v, kc))
35+
36+
P = rand(3, 2)
37+
k = transform(km,LowRankTransform(P))
38+
@test all(params(k) .== params(P, km))
39+
40+
k = transform(km, LowRankTransform(P) ScaleTransform(s))
41+
@test all(params(k) .== params([s], P, km))
42+
43+
c = Chain(Dense(3, 2))
44+
k = transform(km, FunctionTransform(c))
45+
@test all(params(k) .== params(c, km))
46+
47+
end

0 commit comments

Comments
 (0)