Skip to content

Commit 472a9c3

Browse files
authored
Merge branch 'master' into functors
2 parents 91f8c76 + a92fcc2 commit 472a9c3

File tree

6 files changed

+106
-13
lines changed

6 files changed

+106
-13
lines changed

src/basekernels/maha.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ where the matrix P is the metric.
1010
"""
1111
struct MahalanobisKernel{T<:Real, A<:AbstractMatrix{T}} <: SimpleKernel
1212
P::A
13-
function MahalanobisKernel(P::AbstractMatrix{T}) where {T<:Real}
13+
function MahalanobisKernel(; P::AbstractMatrix{T}) where {T<:Real}
1414
LinearAlgebra.checksquare(P)
1515
new{T,typeof(P)}(P)
1616
end

src/transform/selecttransform.jl

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
SelectTransform(dims::AbstractVector{Int})
2+
SelectTransform(dims)
33
44
Select the dimensions `dims` that the kernel is applied to.
55
```
@@ -9,17 +9,11 @@ Select the dimensions `dims` that the kernel is applied to.
99
transform(tr,X,obsdim=2) == X[dims,:]
1010
```
1111
"""
12-
struct SelectTransform{T<:AbstractVector{Int}} <: Transform
12+
struct SelectTransform{T} <: Transform
1313
select::T
14-
function SelectTransform{V}(dims::V) where {V<:AbstractVector{Int}}
15-
@assert all(dims .> 0) "Selective dimensions should all be positive integers"
16-
return new{V}(dims)
17-
end
1814
end
1915

20-
SelectTransform(x::T) where {T<:AbstractVector{Int}} = SelectTransform{T}(x)
21-
22-
set!(t::SelectTransform{<:AbstractVector{T}}, dims::AbstractVector{T}) where {T<:Int} = t.select .= dims
16+
set!(t::SelectTransform, dims) = t.select .= dims
2317

2418
duplicate(t::SelectTransform,θ) = t
2519

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[deps]
2+
AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9"
23
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
34
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
45
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
@@ -13,6 +14,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1314
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1415

1516
[compat]
17+
AxisArrays = "0.4.3"
1618
Distances = "0.9"
1719
FiniteDifferences = "0.10.8"
1820
Flux = "0.10, 0.11"

test/basekernels/maha.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55
v2 = rand(rng, 3)
66

77
P = rand(rng, 3, 3)
8-
k = MahalanobisKernel(P)
8+
k = MahalanobisKernel(P=P)
99

1010
@test kappa(k, x) == exp(-x)
1111
@test k(v1, v2) exp(-sqmahalanobis(v1, v2, P))
1212
@test kappa(ExponentialKernel(), x) == kappa(k, x)
1313
@test repr(k) == "Mahalanobis Kernel (size(P) = $(size(P)))"
14-
# test_ADs(P -> MahalanobisKernel(P), P)
14+
# test_ADs(P -> MahalanobisKernel(P=P), P)
1515
@test_broken "Nothing passes (problem with Mahalanobis distance in Distances)"
1616

1717
test_params(k, (P,))

test/runtests.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using KernelFunctions
2+
using AxisArrays
23
using Distances
34
using Kronecker
45
using LinearAlgebra
@@ -58,12 +59,19 @@ include("test_utils.jl")
5859

5960
@testset "transform" begin
6061
include(joinpath("transform", "transform.jl"))
62+
print(" ")
6163
include(joinpath("transform", "scaletransform.jl"))
64+
print(" ")
6265
include(joinpath("transform", "ardtransform.jl"))
66+
print(" ")
6367
include(joinpath("transform", "lineartransform.jl"))
68+
print(" ")
6469
include(joinpath("transform", "functiontransform.jl"))
70+
print(" ")
6571
include(joinpath("transform", "selecttransform.jl"))
72+
print(" ")
6673
include(joinpath("transform", "chaintransform.jl"))
74+
print(" ")
6775
end
6876
@info "Ran tests on Transform"
6977

test/transform/selecttransform.jl

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,105 @@
88
x_cols = ColVecs(randn(rng, maximum(select), 6))
99
x_rows = RowVecs(randn(rng, 4, maximum(select)))
1010

11-
@testset "$(typeof(x))" for x in [x_vecs, x_cols, x_rows]
11+
Xs = [x_vecs, x_cols, x_rows]
12+
13+
@testset "$(typeof(x))" for x in Xs
1214
x′ = map(t, x)
1315
@test all([t(x[n]) == x[n][select] for n in eachindex(x)])
1416
@test all([t(x[n]) == x′[n] for n in eachindex(x)])
1517
end
1618

19+
symbols = [:a, :b, :c, :d, :e]
20+
select_symbols = [:a, :c, :e]
21+
22+
ts = SelectTransform(select_symbols)
23+
24+
a_vecs = map(x->AxisArray(x, col=symbols), x_vecs)
25+
a_cols = ColVecs(AxisArray(x_cols.X, col=symbols, index=(1:6)))
26+
a_rows = RowVecs(AxisArray(x_rows.X, index=(1:4), col=symbols))
27+
28+
As = [a_vecs, a_cols, a_rows]
29+
30+
@testset "$(typeof(a))" for (a, x) in zip(As, Xs)
31+
a′ = map(ts, a)
32+
x′ = map(t, x)
33+
@test a′ == x′
34+
end
35+
1736
select2 = [2, 3, 5]
1837
KernelFunctions.set!(t, select2)
1938
@test t.select == select2
2039

40+
select_symbols2 = [:b, :c, :e]
41+
KernelFunctions.set!(ts, select_symbols2)
42+
@test ts.select == select_symbols2
43+
2144
@test repr(t) == "Select Transform (dims: $(select2))"
45+
@test repr(ts) == "Select Transform (dims: $(select_symbols2))"
46+
2247
test_ADs(()->transform(SEKernel(), SelectTransform([1,2])))
48+
49+
X = randn(rng, (4, 3))
50+
A = AxisArray(X, row=[:a, :b, :c, :d], col=[:x, :y, :z])
51+
Y = randn(rng, (4, 2))
52+
B = AxisArray(Y, row=[:a, :b, :c, :d], col=[:v, :w])
53+
Z = randn(rng, (2, 3))
54+
C = AxisArray(Z, row=[:e, :f], col=[:x, :y, :z])
55+
56+
tx_row = transform(SEKernel(), SelectTransform([1,2,4]))
57+
ta_row = transform(SEKernel(), SelectTransform([:a,:b,:d]))
58+
tx_col = transform(SEKernel(), SelectTransform([1,3]))
59+
ta_col = transform(SEKernel(), SelectTransform([:x,:z]))
60+
61+
@test kernelmatrix(tx_row, X, obsdim=2) == kernelmatrix(ta_row, A, obsdim=2)
62+
@test kernelmatrix(tx_col, X, obsdim=1) == kernelmatrix(ta_col, A, obsdim=1)
63+
64+
@test kernelmatrix(tx_row, X, Y, obsdim=2) == kernelmatrix(ta_row, A, B, obsdim=2)
65+
@test kernelmatrix(tx_col, X, Z, obsdim=1) == kernelmatrix(ta_col, A, C, obsdim=1)
66+
67+
@testset "$(AD)" for AD in [:Zygote, :ForwardDiff]
68+
gx = gradient(AD, X) do x
69+
testfunction(tx_row, x, 2)
70+
end
71+
ga = gradient(AD, A) do a
72+
testfunction(ta_row, a, 2)
73+
end
74+
@test gx == ga
75+
gx = gradient(AD, X) do x
76+
testfunction(tx_col, x, 1)
77+
end
78+
ga = gradient(AD, A) do a
79+
testfunction(ta_col, a, 1)
80+
end
81+
@test gx == ga
82+
gx = gradient(AD, X) do x
83+
testfunction(tx_row, x, Y, 2)
84+
end
85+
ga = gradient(AD, A) do a
86+
testfunction(ta_row, a, B, 2)
87+
end
88+
@test gx == ga
89+
gx = gradient(AD, X) do x
90+
testfunction(tx_col, x, Z, 1)
91+
end
92+
ga = gradient(AD, A) do a
93+
testfunction(ta_col, a, C, 1)
94+
end
95+
@test gx == ga
96+
end
97+
98+
@testset "$(AD)" for AD in [:ReverseDiff]
99+
@test_broken ga = gradient(AD, A) do a
100+
testfunction(ta_row, a, 2)
101+
end
102+
@test_broken ga = gradient(AD, A) do a
103+
testfunction(ta_col, a, 1)
104+
end
105+
@test_broken ga = gradient(AD, A) do a
106+
testfunction(ta_row, a, B, 2)
107+
end
108+
@test_broken ga = gradient(AD, A) do a
109+
testfunction(ta_col, a, C, 1)
110+
end
111+
end
23112
end

0 commit comments

Comments
 (0)