Skip to content

Commit b070c02

Browse files
committed
Corrected LowRankTransform Added TransformChain and added Tests
1 parent 5772186 commit b070c02

File tree

3 files changed

+65
-16
lines changed

3 files changed

+65
-16
lines changed

src/transform/lowranktransform.jl

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,12 @@
1-
struct LowRankTransform{T<:AbstractMatrix{<:Real}}} <: Transform
1+
struct LowRankTransform{T<:AbstractMatrix{<:Real}} <: Transform
22
proj::T
33
end
44

5-
function LowRankTransform(proj::M) where {M<:AbstractMatrix{<:Real}}
6-
# @check_args(LowRankTransform, proj, A > zero(T), "s > 0")
7-
LowRankTransform{T}(proj)
8-
end
9-
10-
Base.size(tr::LowRankTransform{<:Real},i::Int) = size(tr.proj,i)
11-
Base.size(tr::LowRankTransform{<:Real}) = size(tr.proj)
5+
Base.size(tr::LowRankTransform,i::Int) = size(tr.proj,i)
6+
Base.size(tr::LowRankTransform) = size(tr.proj)
127

13-
function transform(t::LowRankTransform{<:AbstractVector{<:Real}},X::AbstractMatrix{<:Real},obsdim::Int)
14-
@boundscheck if size(t,2) != size(X,!Bool(obsdim-1)+1)
8+
function transform(t::LowRankTransform,X::AbstractMatrix{<:Real},obsdim::Int)
9+
@boundscheck if size(t,2) != size(X,1)
1510
throw(DimensionMismatch("The projection matrix has size $(size(t)) and cannot be used on X with dimensions $(size(X))"))
1611
end
1712
_transform(t,X,obsdim)

src/transform/transform.jl

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,45 @@
1+
export Transform, ScaleTransform, LowRankTransform, FunctionTransform, TransformChain
2+
3+
14
abstract type Transform end
25

6+
include("scaletransform.jl")
7+
include("lowranktransform.jl")
8+
39
struct TransformChain <: Transform
410
transforms::Vector{Transform}
511
end
612

13+
Base.length(t::TransformChain) = length(t.transforms)
14+
715
function TransformChain(v::AbstractVector{<:Transform})
816
TransformChain(v)
917
end
1018

11-
struct InputTransform{F} <: Transform
12-
f::F
19+
function transform(t::TransformChain,X::T,obsdim::Int=1) where {T}
20+
Xtr = copy(X)
21+
for tr in t.transforms
22+
Xtr = transform(tr,Xtr,obsdim)
23+
end
24+
return Xtr
1325
end
1426

15-
# function InputTransform(f::F) where {F}
16-
# InputTransform{F}(f)
17-
# end
27+
Base.:(t₁::Transform,t₂::Transform) = TransformChain([t₂,t₁])
28+
Base.:(t::Transform,tc::TransformChain) = TransformChain(vcat(tc.transforms,t))
29+
Base.:(tc::TransformChain,t::Transform) = TransformChain(vcat(t,tc.transforms))
30+
31+
"""
32+
FunctionTransform
33+
34+
Take a function `f` as an argument which is going to act on the matrix as a whole.
35+
Make sure that `f` is supposed to act on a matrix by eventually using broadcasting
36+
For example `f(x)=sin(x)` -> `f(x)=sin.(x)`
37+
"""
38+
struct FunctionTransform{F} <: Transform
39+
f::F
40+
end
1841

19-
transform(t::InputTransform,x::T,obsdim::Int=1) where {T} = t.f(X)
42+
transform(t::FunctionTransform,X::T,obsdim::Int=1) where {T} = t.f(X)
2043

2144

2245
struct IdentityTransform <: Transform end

test/test_transform.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
using Test
2+
using KernelFunctions
3+
using Random: seed!
4+
5+
seed!(42)
6+
7+
dims = (10,5)
8+
X = rand(dims...)
9+
##
10+
s = 3.0
11+
v1 = vcat(3.0,4.0*ones(dims[2]-1))
12+
v2 = vcat(3.0,4.0*ones(dims[1]-1))
13+
t = ScaleTransform(s)
14+
vt1 = ScaleTransform(v1)
15+
vt2 = ScaleTransform(v2)
16+
@test all(KernelFunctions.transform(t,X).==s*X)
17+
@test all(KernelFunctions.transform(vt1,X,1).==v1'.*X)
18+
@test all(KernelFunctions.transform(vt2,X,2).==v2.*X)
19+
##
20+
P = rand(5,10)
21+
tp = LowRankTransform(P)
22+
@test all(KernelFunctions.transform(tp,X,2).==P*X)
23+
##
24+
f(x) = sin.(x)
25+
tf = FunctionTransform(f)
26+
@test all(KernelFunctions.transform(tf,X,1).==f(X))
27+
##
28+
tchain = TransformChain([t,tp,tf])
29+
ttptf
30+
TransformChain([t,tp])
31+
@test all(KernelFunctions.transform(tchain,X).==f(P*(s*X)))

0 commit comments

Comments
 (0)