Skip to content

Commit 2f1229d

Browse files
committed
Added trainable for all transforms
1 parent d08f329 commit 2f1229d

File tree

7 files changed

+15
-13
lines changed

7 files changed

+15
-13
lines changed

src/transform/ardtransform.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ function set!(t::ARDTransform{T},ρ::AbstractVector{T}) where {T<:Real}
2525
t.v .= ρ
2626
end
2727

28-
params(t::ARDTransform) = t.v
28+
trainable(t::ARDTransform) = (t.v,)
2929
dim(t::ARDTransform) = length(t.v)
3030

3131
function apply(t::ARDTransform,X::AbstractMatrix{<:Real};obsdim::Int = defaultobs)

src/transform/chaintransform.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ function apply(t::ChainTransform,X::T;obsdim::Int=defaultobs) where {T}
3232
end
3333

3434
set!(t::ChainTransform,θ) = set!.(t.transforms,θ)
35-
params(t::ChainTransform) = (params.(t.transforms))
35+
trainable(t::ChainTransform) = t.transforms
3636
duplicate(t::ChainTransform,θ) = ChainTransform(duplicate.(t.transforms,θ))
3737

3838

src/transform/functiontransform.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ FunctionTransform
44
f(x) = abs.(x)
55
tr = FunctionTransform(f)
66
```
7-
Take a function `f` as an argument which is going to act on each vector individually.
7+
Take a function or object `f` as an argument which is going to act on each vector individually.
88
Make sure that `f` is supposed to act on a vector by eventually using broadcasting
99
For example `f(x)=sin(x)` -> `f(x)=sin.(x)`
1010
"""
@@ -15,4 +15,4 @@ end
1515
apply(t::FunctionTransform, X::T; obsdim::Int = defaultobs) where {T} = mapslices(t.f, X, dims = feature_dim(obsdim))
1616

1717
duplicate(t::FunctionTransform,f) = FunctionTransform(f)
18-
params(t::FunctionTransform) = t.f
18+
trainable(t::FunctionTransform) = (t.f,)

src/transform/lowranktransform.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ function set!(t::LowRankTransform{<:AbstractMatrix{T}},M::AbstractMatrix{T}) whe
1616
t.proj .= M
1717
end
1818

19-
params(t::LowRankTransform) = t.proj
19+
trainable(t::LowRankTransform) = (t.proj,)
2020

2121
Base.size(tr::LowRankTransform,i::Int) = size(tr.proj,i)
2222
Base.size(tr::LowRankTransform) = size(tr.proj) # TODO Add test

src/transform/scaletransform.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ function ScaleTransform(s::T=1.0) where {T<:Real}
1616
end
1717

1818
set!(t::ScaleTransform::Real) = t.s .= [ρ]
19-
params(t::ScaleTransform) = t.s
19+
trainable(t::ScaleTransform) = (t.s,)
2020
dim(str::ScaleTransform) = 1
2121

2222
apply(t::ScaleTransform,x::AbstractVecOrMat;obsdim::Int=defaultobs) = first(t.s) * x

src/transform/selecttransform.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@ end
2222

2323
set!(t::SelectTransform{<:AbstractVector{T}},dims::AbstractVector{T}) where {T<:Int} = t.select .= dims
2424

25-
params(t::SelectTransform) = t.select
26-
2725
duplicate(t::SelectTransform,θ) = t
2826

2927

src/transform/transform.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,19 @@ include("functiontransform.jl")
77
include("selecttransform.jl")
88
include("chaintransform.jl")
99

10+
"""
11+
`apply(t::Transform, x; obsdim::Int=defaultobs)`
12+
Apply the transform `t` per slice on the array `x`
13+
"""
14+
apply
15+
1016
"""
1117
IdentityTransform
1218
Return exactly the input
1319
"""
1420
struct IdentityTransform <: Transform end
1521

16-
params(t::IdentityTransform) = nothing
17-
18-
apply(t::IdentityTransform, x; obsdim::Int=defaultobs) = x #TODO add test
22+
apply(t::IdentityTransform, x; obsdim::Int=defaultobs) = x
1923

2024
### TODO Maybe defining adjoints could help but so far it's not working
2125

@@ -32,9 +36,9 @@ apply(t::IdentityTransform, x; obsdim::Int=defaultobs) = x #TODO add test
3236

3337
# @adjoint transform(t::ScaleTransform{<:AbstractVector{<:Real}},x::AbstractVector{<:Real}) = transform(t,x),Δ->(ScaleTransform(nothing),t.s.*Δ)
3438
#
35-
# @adjoint transform(t::ScaleTransform{<:AbstractVector{<:Real}},X::AbstractMatrix{<:Real},obsdim::Int) = transform(t,X,obsdim),Δ->begin
39+
# @adjoint transform(t::ARDTransform{<:Real},X::AbstractMatrix{<:Real},obsdim::Int) = transform(t,X,obsdim),Δ->begin
3640
# @show Δ,size(Δ);
37-
# return (obsdim == 1 ? ScaleTransform()Δ'.*X : ScaleTransform()Δ.*X,transform(t,Δ,obsdim),nothing)
41+
# return (obsdim == 1 ? ARD()Δ'.*X : ScaleTransform()Δ.*X,transform(t,Δ,obsdim),nothing)
3842
# end
3943
#
4044
# @adjoint transform(t::ScaleTransform{T},x::AbstractVecOrMat,obsdim::Int) where {T<:Real} = transform(t,x), Δ->(ScaleTransform(one(T)),t.s.*Δ,nothing)

0 commit comments

Comments
 (0)