Skip to content

Commit aa27cd6

Browse files
committed
Added ColVecs
1 parent bb3e859 commit aa27cd6

File tree

6 files changed

+78
-2
lines changed

6 files changed

+78
-2
lines changed

src/transform/transform.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ Return exactly the input
1919
"""
2020
struct IdentityTransform <: Transform end
2121

22-
apply(t::IdentityTransform, x; obsdim::Int=defaultobs) = x
22+
apply(t::IdentityTransform, x; obsdim::Int = defaultobs) = x
23+
24+
apply(t::Transform, x::ColVecs; obsdim::Int = defaultobs) = ColVecs(apply(t, x.X, obsdim = 1))
2325

2426
### TODO Maybe defining adjoints could help but so far it's not working
2527

src/utils.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,30 @@ macro check_args(K, param, cond, desc=string(cond))
1010
end
1111

1212

13+
"""
14+
ColVecs{T, TX<:AbstractMatrix}
15+
16+
A lightweight box for an `AbstractMatrix` to make it behave like a vector of vectors.
17+
"""
18+
struct ColVecs{T, TX<:AbstractMatrix{T}} <: AbstractVector{Vector{T}}
19+
X::TX
20+
ColVecs(X::TX) where {T, TX<:AbstractMatrix{T}} = new{T, TX}(X)
21+
end
22+
23+
Base.:(==)(D1::ColVecs, D2::ColVecs) = D1.X == D2.X
24+
Base.size(D::ColVecs) = (size(D.X, 2),)
25+
Base.length(D::ColVecs) = size(D.X, 2)
26+
Base.getindex(D::ColVecs, n::Int) = D.X[:, n]
27+
Base.getindex(D::ColVecs, n::CartesianIndex{1}) = getindex(D, n[1])
28+
Base.getindex(D::ColVecs, n) = ColVecs(D.X[:, n])
29+
Base.view(D::ColVecs, n::Int) = view(D.X, :, n)
30+
Base.view(D::ColVecs, n) = ColVecs(view(D.X, :, n))
31+
Base.eltype(D::ColVecs{T}) where T = Vector{T}
32+
Base.zero(D::ColVecs) = ColVecs(zero(D.X))
33+
Base.iterate(D::ColVecs) = (view(D.X, :, 1), 2)
34+
Base.iterate(D::ColVecs, state) = state > length(D) ? nothing : (view(D.X, :, state), state + 1)
35+
36+
1337
# Take highest Float among possibilities
1438
# function promote_float(Tₖ::DataType...)
1539
# if length(Tₖ) == 0

src/zygote_adjoints.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,15 @@
44
end
55
end
66

7+
@adjoint function ColVecs(X::AbstractMatrix)
8+
back::NamedTuple) =.X,)
9+
back::AbstractMatrix) = (Δ,)
10+
function back::AbstractVector{<:AbstractVector{<:Real}})
11+
throw(error("In slow method"))
12+
end
13+
return ColVecs(X), back
14+
end
15+
716
# @adjoint function evaluate(s::Sinus, x::AbstractVector, y::AbstractVector)
817
# d = evaluate(s, x, y)
918
# s = sum(sin.(π*(x-y)))

test/trainable.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
@testset "trainable" begin
2+
using Flux: params
23
ν = 2.0; c = 3.0; d = 2.0; γ = 2.0; α = 2.5; h = 0.5; r = rand(3)
34

45
kc = ConstantKernel(c=c)

test/transform/transform.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,16 @@
33
rng = MersenneTwister(123546)
44
X = rand(rng, dims...)
55
@testset "IdentityTransform" begin
6-
@test KernelFunctions.apply(IdentityTransform(),X)==X
6+
@test KernelFunctions.apply(IdentityTransform(), X) == X
7+
end
8+
@testset "ColVecs" begin
9+
vX = KernelFunctions.ColVecs(X)
10+
t = ARDTransform(rand(dims[1]))
11+
@test KernelFunctions.apply(t, vX) KernelFunctions.ColVecs(KernelFunctions.apply(t, X, obsdim = 2))
12+
13+
Y = rand(rng, reverse(dims)...)
14+
vY = KernelFunctions.ColVecs(Y')
15+
t = ARDTransform(rand(dims[1]))
16+
@test KernelFunctions.apply(t, vY) KernelFunctions.ColVecs(KernelFunctions.apply(t, Y, obsdim = 1)')
717
end
818
end

test/utils.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,33 @@
11
@testset "utils" begin
2+
using KernelFunctions: ColVecs
3+
rng, N, D = MersenneTwister(123456), 10, 2
4+
x, X = randn(rng, N), randn(rng, D, N)
25

6+
# Test Matrix data sets.
7+
@testset "ColVecs" begin
8+
DX = ColVecs(X)
9+
@test DX == DX
10+
@test size(DX) == (N,)
11+
@test length(DX) == N
12+
@test getindex(DX, 5) isa Vector
13+
@test getindex(DX, 5) == X[:, 5]
14+
@test getindex(DX, 1:2:6) isa ColVecs
15+
@test getindex(DX, 1:2:6) == ColVecs(X[:, 1:2:6])
16+
@test view(DX, 4) isa AbstractVector
17+
@test view(DX, 4) == view(X, :, 4)
18+
@test view(DX, 1:2:4) isa ColVecs
19+
@test view(DX, 1:2:4) == ColVecs(view(X, :, 1:2:4))
20+
@test eltype(DX) == Vector{Float64}
21+
@test eachindex(DX) == 1:N
22+
23+
let
24+
@test Zygote.pullback(ColVecs, X)[1] == DX
25+
DX, back = Zygote.pullback(ColVecs, X)
26+
@test back((X=ones(size(X)),))[1] == ones(size(X))
27+
28+
@test Zygote.pullback(DX->DX.X, DX)[1] == X
29+
X_, back = Zygote.pullback(DX->DX.X, DX)
30+
@test back(ones(size(X)))[1].X == ones(size(X))
31+
end
32+
end
333
end

0 commit comments

Comments
 (0)