Skip to content

Commit 462c43f

Browse files
committed
Use LinearTransform instead of LowRankTransform
1 parent 754bee4 commit 462c43f

File tree

11 files changed

+54
-52
lines changed

11 files changed

+54
-52
lines changed

docs/create_kernel_plots.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ p = heatmap(K2,yflip=true,colorbar=false,framestyle=:none,background_color=RGBA(
2323
savefig(joinpath(@__DIR__,"src","assets","heatmap_matern.png"))
2424

2525

26-
k = transform(PolynomialKernel(c=0.0,d=2.0),LowRankTransform(randn(3,1)))
26+
k = transform(PolynomialKernel(c=0.0,d=2.0), LinearTransform(randn(3,1)))
2727
K3 = kernelmatrix(k,xrange,obsdim=1)
2828
p = heatmap(K3,yflip=true,colorbar=false,framestyle=:none,background_color=RGBA(0.0,0.0,0.0,0.0))
2929
savefig(joinpath(@__DIR__,"src","assets","heatmap_poly.png"))

src/KernelFunctions.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ export KernelSum, KernelProduct
2323
export TransformedKernel, ScaledKernel
2424
export TensorProduct
2525

26-
export Transform, SelectTransform, ChainTransform, ScaleTransform, LowRankTransform, IdentityTransform, FunctionTransform
26+
export Transform, SelectTransform, ChainTransform, ScaleTransform, LinearTransform,
27+
ARDTransform, IdentityTransform, FunctionTransform
2728

2829
export NystromFact, nystrom
2930

src/trainable.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,6 @@ trainable(t::ChainTransform) = t.transforms
4242

4343
trainable(t::FunctionTransform) = (t.f,)
4444

45-
trainable(t::LowRankTransform) = (t.proj,)
45+
trainable(t::LinearTransform) = (t.A,)
4646

4747
trainable(t::ScaleTransform) = (t.s,)

src/transform/chaintransform.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
Chain a series of transform, here `t1` will be called first
55
```
66
t1 = ScaleTransform()
7-
t2 = LowRankTransform(rand(3,4))
7+
t2 = LinearTransform(rand(3,4))
88
ct = ChainTransform([t1,t2]) #t1 will be called first
99
ct == t2 ∘ t1
1010
```

src/transform/lineartransform.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
"""
2+
LinearTransform(A::AbstractMatrix)
3+
4+
Apply the linear transformation realised by the matrix `A`.
5+
6+
The second dimension of `A` must match the number of features of the target.
7+
8+
# Examples
9+
10+
```julia-repl
11+
julia> A = rand(10, 5)
12+
13+
julia> tr = LinearTransform(A)
14+
```
15+
"""
16+
struct LinearTransform{T<:AbstractMatrix{<:Real}} <: Transform
17+
A::T
18+
end
19+
20+
function set!(t::LinearTransform{<:AbstractMatrix{T}}, A::AbstractMatrix{T}) where {T<:Real}
21+
@assert size(t.A) == size(A) "Size of the given matrix $(size(A)) and the transformation matrix $(size(t.A)) are not the same"
22+
t.A .= A
23+
end
24+
25+
(t::LinearTransform)(x::Real) = vec(t.A * x)
26+
(t::LinearTransform)(x::AbstractVector{<:Real}) = t.A * x
27+
28+
function Base.map(t::LinearTransform, x::AbstractVector{<:Real})
29+
return ColVecs(t.A * x')
30+
end
31+
Base.map(t::LinearTransform, x::ColVecs) = ColVecs(t.A * x.X)
32+
Base.map(t::LinearTransform, x::RowVecs) = RowVecs(x.X * t.A')
33+
34+
function Base.show(io::IO, t::LinearTransform)
35+
print(io::IO, "Linear transform (size(A) = ", size(t.A), ")")
36+
end

src/transform/lowranktransform.jl

Lines changed: 0 additions & 33 deletions
This file was deleted.

src/transform/transform.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
export Transform, IdentityTransform, ScaleTransform, ARDTransform, LowRankTransform, FunctionTransform, ChainTransform
2-
31
include("scaletransform.jl")
42
include("ardtransform.jl")
5-
include("lowranktransform.jl")
3+
include("lineartransform.jl")
64
include("functiontransform.jl")
75
include("selecttransform.jl")
86
include("chaintransform.jl")

test/test_AD.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,10 @@ end
4545
transform_AD(Val(Symbol($AD)),ScaleTransform(l),A)
4646
# ARD Transform
4747
transform_AD(Val(Symbol($AD)),ARDTransform(vl),A)
48-
# LowRankTransform
49-
transform_AD(Val(Symbol($AD)),LowRankTransform(rand(2,10)),A)
48+
# Linear transform
49+
transform_AD(Val(Symbol($AD)), LinearTransform(rand(2,10)),A)
5050
# Chain Transform
51-
# transform_AD(Val(Symbol($AD)),LowRankTransform,A)
51+
# transform_AD(Val(Symbol($AD)), LinearTransform, A)
5252
end
5353
end
5454
end

test/trainable.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,10 @@
4444
@test all(params(k) .== params(v, kc))
4545

4646
P = rand(3, 2)
47-
k = transform(km,LowRankTransform(P))
47+
k = transform(km, LinearTransform(P))
4848
@test all(params(k) .== params(P, km))
4949

50-
k = transform(km, LowRankTransform(P) ScaleTransform(s))
50+
k = transform(km, LinearTransform(P) ScaleTransform(s))
5151
@test all(params(k) .== params([s], P, km))
5252

5353
c = Chain(Dense(3, 2))

test/transform/chaintransform.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
rng = MersenneTwister(123546)
33

44
P = rand(rng, 3, 2)
5-
tp = LowRankTransform(P)
5+
tp = LinearTransform(P)
66

77
f(x) = sin.(x)
88
tf = FunctionTransform(f)

0 commit comments

Comments
 (0)