Skip to content

Commit ea06b60

Browse files
committed
Uniformized the obsdim argument over the transforms
1 parent 4367140 commit ea06b60

File tree

4 files changed

+19
-22
lines changed

4 files changed

+19
-22
lines changed

src/transform/functiontransform.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
"""
2+
FunctionTransform
3+
4+
Take a function `f` as an argument which is going to act on each vector individually.
5+
Make sure that `f` is supposed to act on a vector by eventually using broadcasting
6+
For example `f(x)=sin(x)` -> `f(x)=sin.(x)`
7+
"""
8+
struct FunctionTransform{F} <: Transform
9+
f::F
10+
end
11+
12+
transform(t::FunctionTransform,X::T,obsdim::Int=defaultobs) where {T} = mapslices(t.f,X,dims=obsdim)

src/transform/lowranktransform.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ end
55
Base.size(tr::LowRankTransform,i::Int) = size(tr.proj,i)
66
Base.size(tr::LowRankTransform) = size(tr.proj)
77

8-
function transform(t::LowRankTransform,X::AbstractMatrix{<:Real},obsdim::Int)
8+
function transform(t::LowRankTransform,X::AbstractMatrix{<:Real},obsdim::Int=defaultobs)
99
@boundscheck size(t,2) != size(X,feature_dim(obsdim)) ?
1010
throw(DimensionMismatch("The projection matrix has size $(size(t)) and cannot be used on X with dimensions $(size(X))")) : nothing
1111
@inbounds _transform(t,X,obsdim)
@@ -15,4 +15,4 @@ function transform(t::LowRankTransform,x::AbstractVector{<:Real})
1515
t.proj*X
1616
end
1717

18-
_transform(t::LowRankTransform,X::AbstractVecOrMat{<:Real},obsdim::Int) = obsdim == 2 ? t.proj * X : X * t.proj'
18+
_transform(t::LowRankTransform,X::AbstractVecOrMat{<:Real},obsdim::Int=defaultobs) = obsdim == 2 ? t.proj * X : X * t.proj'

src/transform/scaletransform.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ function transform(t::ScaleTransform{<:AbstractVector{<:Real}},X::AbstractMatrix
2727
_transform(t,X,obsdim)
2828
end
2929
_transform(t::ScaleTransform{<:AbstractVector{<:Real}},x::AbstractVector{<:Real}) = t.s .* x
30-
_transform(t::ScaleTransform{<:AbstractVector{<:Real}},X::AbstractMatrix{<:Real},obsdim::Int) = obsdim == 1 ? t.s'.*X : t.s .* X
30+
_transform(t::ScaleTransform{<:AbstractVector{<:Real}},X::AbstractMatrix{<:Real},obsdim::Int=defaultobs) = obsdim == 1 ? t.s'.*X : t.s .* X
3131

32-
transform(t::ScaleTransform{<:Real},x::AbstractVecOrMat,obsdim::Int) = transform(t,x)
33-
transform(t::ScaleTransform{<:Real},x::AbstractVecOrMat) = t.s .* x
32+
transform(t::ScaleTransform{<:Real},x::AbstractVecOrMat,obsdim::Int=defaultobs) = t.s .* x

src/transform/transform.jl

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ abstract type Transform end
55

66
include("scaletransform.jl")
77
include("lowranktransform.jl")
8+
include("functiontransform.jl")
89

910
struct TransformChain <: Transform
1011
transforms::Vector{Transform}
@@ -16,7 +17,7 @@ function TransformChain(v::AbstractVector{<:Transform})
1617
TransformChain(v)
1718
end
1819

19-
function transform(t::TransformChain,X::T,obsdim::Int=1) where {T}
20+
function transform(t::TransformChain,X::T,obsdim::Int=defaultobs) where {T}
2021
Xtr = copy(X)
2122
for tr in t.transforms
2223
Xtr = transform(tr,Xtr,obsdim)
@@ -28,24 +29,9 @@ Base.:∘(t₁::Transform,t₂::Transform) = TransformChain([t₂,t₁])
2829
Base.:(t::Transform,tc::TransformChain) = TransformChain(vcat(tc.transforms,t))
2930
Base.:(tc::TransformChain,t::Transform) = TransformChain(vcat(t,tc.transforms))
3031

31-
"""
32-
FunctionTransform
33-
34-
Take a function `f` as an argument which is going to act on each vector individually.
35-
Make sure that `f` is supposed to act on a vector by eventually using broadcasting
36-
For example `f(x)=sin(x)` -> `f(x)=sin.(x)`
37-
"""
38-
struct FunctionTransform{F} <: Transform
39-
f::F
40-
end
41-
42-
transform(t::FunctionTransform,X::T,obsdim::Int=1) where {T} = mapslices(t.f,X,obsdim)
43-
44-
4532
struct IdentityTransform <: Transform end
4633

47-
transform(t::IdentityTransform,x::AbstractArray,obsdim::Int) = transform(t,x)
48-
transform(t::IdentityTransform,x::AbstractArray) = return x
34+
transform(t::IdentityTransform,x::AbstractArray,obsdim::Int=defaultobs) = x
4935

5036
### TODO Maybe defining adjoints could help but so far it's not working
5137

0 commit comments

Comments
 (0)