Skip to content

Commit d88dcff

Browse files
committed
Created indirection from Base.map to _map for creating adjoints
1 parent 2ae0cd6 commit d88dcff

File tree

6 files changed

+11
-15
lines changed

6 files changed

+11
-15
lines changed

src/transform/ardtransform.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ dim(t::ARDTransform) = length(t.v)
2525
(t::ARDTransform)(x) = t.v .* x
2626

2727
Base.map(t::ARDTransform, x::AbstractVector{<:Real}) = t.v' .* x
28-
Base.map(t::ARDTransform, x::ColVecs) = ColVecs(t.v .* x.X)
29-
Base.map(t::ARDTransform, x::RowVecs) = RowVecs(t.v' .* x.X)
28+
_map(t::ARDTransform, x::ColVecs) = ColVecs(t.v .* x.X)
29+
_map(t::ARDTransform, x::RowVecs) = RowVecs(t.v' .* x.X)
3030

3131
Base.isequal(t::ARDTransform, t2::ARDTransform) = isequal(t.v, t2.v)
3232

src/transform/functiontransform.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ end
1616
(t::FunctionTransform)(x) = t.f(x)
1717

1818
Base.map(t::FunctionTransform, x::AbstractVector{<:Real}) = map(t.f, x)
19-
Base.map(t::FunctionTransform, x::ColVecs) = ColVecs(mapslices(t.f, x.X; dims=1))
20-
Base.map(t::FunctionTransform, x::RowVecs) = RowVecs(mapslices(t.f, x.X; dims=2))
19+
_map(t::FunctionTransform, x::ColVecs) = ColVecs(mapslices(t.f, x.X; dims=1))
20+
_map(t::FunctionTransform, x::RowVecs) = RowVecs(mapslices(t.f, x.X; dims=2))
2121

2222
duplicate(t::FunctionTransform,f) = FunctionTransform(f)
2323

src/transform/lineartransform.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ end
2828
(t::LinearTransform)(x::AbstractVector{<:Real}) = t.A * x
2929

3030
Base.map(t::LinearTransform, x::AbstractVector{<:Real}) = ColVecs(t.A * x')
31-
Base.map(t::LinearTransform, x::ColVecs) = ColVecs(t.A * x.X)
32-
Base.map(t::LinearTransform, x::RowVecs) = RowVecs(x.X * t.A')
31+
_map(t::LinearTransform, x::ColVecs) = ColVecs(t.A * x.X)
32+
_map(t::LinearTransform, x::RowVecs) = RowVecs(x.X * t.A')
3333

3434
function Base.show(io::IO, t::LinearTransform)
3535
print(io::IO, "Linear transform (size(A) = ", size(t.A), ")")

src/transform/scaletransform.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ set!(t::ScaleTransform,ρ::Real) = t.s .= [ρ]
2020
(t::ScaleTransform)(x) = first(t.s) .* x
2121

2222
Base.map(t::ScaleTransform, x::AbstractVector{<:Real}) = first(t.s) .* x
23-
Base.map(t::ScaleTransform, x::ColVecs) = ColVecs(first(t.s) .* x.X)
24-
Base.map(t::ScaleTransform, x::RowVecs) = RowVecs(first(t.s) .* x.X)
23+
_map(t::ScaleTransform, x::ColVecs) = ColVecs(first(t.s) .* x.X)
24+
_map(t::ScaleTransform, x::RowVecs) = RowVecs(first(t.s) .* x.X)
2525

2626
Base.isequal(t::ScaleTransform,t2::ScaleTransform) = isequal(first(t.s),first(t2.s))
2727

src/transform/selecttransform.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ duplicate(t::SelectTransform,θ) = t
2525

2626
(t::SelectTransform)(x::AbstractVector) = view(x, t.select)
2727

28-
Base.map(t::SelectTransform, x::ColVecs) = ColVecs(view(x.X, t.select, :))
29-
Base.map(t::SelectTransform, x::RowVecs) = RowVecs(view(x.X, :, t.select))
28+
_map(t::SelectTransform, x::ColVecs) = ColVecs(view(x.X, t.select, :))
29+
_map(t::SelectTransform, x::RowVecs) = RowVecs(view(x.X, :, t.select))
3030

3131
Base.show(io::IO, t::SelectTransform) = print(io, "Select Transform (dims: ", t.select, ")")

src/transform/transform.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,8 @@ include("functiontransform.jl")
55
include("selecttransform.jl")
66
include("chaintransform.jl")
77

8-
"""
9-
apply(t::Transform, x; obsdim::Int=defaultobs)
108

11-
Apply the transform `t` vector-wise on the array `x`
12-
"""
13-
apply
9+
Base.map(t::Transform, x::Union{ColVecs, RowVecs}) = _map(t, x)
1410

1511
"""
1612
IdentityTransform()

0 commit comments

Comments
 (0)