Skip to content

Commit 4367140

Browse files
committed
Corrected mistakes in transform
1 parent 6f467a3 commit 4367140

File tree

2 files changed

+12
-9
lines changed

2 files changed

+12
-9
lines changed

src/transform/lowranktransform.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,13 @@ Base.size(tr::LowRankTransform,i::Int) = size(tr.proj,i)
66
Base.size(tr::LowRankTransform) = size(tr.proj)
77

88
function transform(t::LowRankTransform,X::AbstractMatrix{<:Real},obsdim::Int)
9-
@boundscheck if size(t,2) != size(X,1)
10-
throw(DimensionMismatch("The projection matrix has size $(size(t)) and cannot be used on X with dimensions $(size(X))"))
11-
end
12-
_transform(t,X,obsdim)
9+
@boundscheck size(t,2) != size(X,feature_dim(obsdim)) ?
10+
throw(DimensionMismatch("The projection matrix has size $(size(t)) and cannot be used on X with dimensions $(size(X))")) : nothing
11+
@inbounds _transform(t,X,obsdim)
1312
end
14-
_transform(t::LowRankTransform,x::AbstractVector{<:Real}) = t.proj * x
15-
_transform(t::LowRankTransform,X::AbstractMatrix{<:Real},obsdim::Int) = t.proj * X
13+
function transform(t::LowRankTransform,x::AbstractVector{<:Real})
14+
@assert size(t,2) == length(x) "Vector has wrong dimensions"
15+
t.proj*X
16+
end
17+
18+
_transform(t::LowRankTransform,X::AbstractVecOrMat{<:Real},obsdim::Int) = obsdim == 2 ? t.proj * X : X * t.proj'

src/transform/transform.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,15 @@ Base.:∘(tc::TransformChain,t::Transform) = TransformChain(vcat(t,tc.transforms
3131
"""
3232
FunctionTransform
3333
34-
Take a function `f` as an argument which is going to act on the matrix as a whole.
35-
Make sure that `f` is supposed to act on a matrix by eventually using broadcasting
34+
Take a function `f` as an argument which is going to act on each vector individually.
35+
Make sure that `f` is supposed to act on a vector by eventually using broadcasting
3636
For example `f(x)=sin(x)` -> `f(x)=sin.(x)`
3737
"""
3838
struct FunctionTransform{F} <: Transform
3939
f::F
4040
end
4141

42-
transform(t::FunctionTransform,X::T,obsdim::Int=1) where {T} = t.f(X)
42+
transform(t::FunctionTransform,X::T,obsdim::Int=1) where {T} = mapslices(t.f,X,obsdim)
4343

4444

4545
struct IdentityTransform <: Transform end

0 commit comments

Comments
 (0)