Skip to content

Commit 57cbd5a

Browse files
feat: add svm training option
1 parent eca86f2 commit 57cbd5a

File tree

6 files changed

+148
-28
lines changed

6 files changed

+148
-28
lines changed

ext/RCLIBSVMExt.jl

Lines changed: 57 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,71 @@
11
module RCLIBSVMExt
2-
using ReservoirComputing
2+
33
using LIBSVM
4+
using ReservoirComputing:
5+
SVMReadout, addreadout!, ReservoirChain
6+
import ReservoirComputing: train
47

5-
function ReservoirComputing.train(svr::LIBSVM.AbstractSVR,
6-
states::AbstractArray, target::AbstractArray)
7-
out_size = size(target, 1)
8-
output_matrix = []
8+
function train(svr::LIBSVM.AbstractSVR,
9+
states::AbstractArray, target::AbstractArray)
10+
@assert size(states, 2) == size(target, 2) "states and target must share columns."
11+
perm_states = permutedims(states)
12+
size_target = size(target, 1)
913

10-
if out_size == 1
11-
output_matrix = LIBSVM.fit!(svr, states', vec(target))
14+
if size_target == 1
15+
vec_target = vec(target)
16+
model = LIBSVM.fit!(svr, perm_states, vec_target)
17+
return model
1218
else
13-
for i in 1:out_size
14-
push!(output_matrix, LIBSVM.fit!(svr, states', target[i, :]))
19+
models = Vector{Any}(undef, size_target)
20+
for (idx, row_target) in enumerate(eachrow(target))
21+
models[idx] = LIBSVM.fit!(svr, perm_states, row_target)
1522
end
23+
return models
1624
end
17-
18-
return OutputLayer(svr, output_matrix, out_size, target[:, end])
1925
end
2026

21-
function ReservoirComputing.get_prediction(training_method::LIBSVM.AbstractSVR,
22-
output_layer::AbstractArray, x::AbstractArray)
23-
out = zeros(output_layer.out_size)
27+
_has_models(ps) = (ps isa NamedTuple) && (:models in propertynames(ps))
28+
29+
function (svmro::SVMReadout)(inp::AbstractArray, ps, st::NamedTuple)
30+
if !_has_models(ps)
31+
return inp, st
32+
end
33+
models = getfield(ps, :models)
34+
35+
vec_like = false
36+
if ndims(inp) == 1
37+
reshaped_inp = reshape(inp, 1, :)
38+
num_imp = 1
39+
vec_like = true
40+
elseif ndims(inp) == 2
41+
if size(inp, 2) == 1
42+
reshaped_inp = reshape(vec(inp), 1, :)
43+
num_inp = 1
44+
vec_like = true
45+
else
46+
reshaped_inp = permutedims(inp)
47+
num_imp = size(reshaped_inp, 1)
48+
end
49+
else
50+
throw(ArgumentError("SVMReadout expects 1D or 2D input; got size $(size(inp))"))
51+
end
2452

25-
for i in 1:(output_layer.out_size)
26-
x_new = reshape(x, 1, length(x))
27-
out[i] = LIBSVM.predict(output_layer.output_matrix[i], x_new)[1]
53+
if models isa AbstractVector
54+
out_data = Array{float(eltype(reshaped_inp))}(undef, svmro.out_dims, num_imp)
55+
@inbounds for i in 1:svmro.out_dims
56+
single_out = LIBSVM.predict(models[i], reshaped_inp)
57+
out_data[i, :] = single_out
58+
end
59+
else
60+
single_out = LIBSVM.predict(models, reshaped_inp)
61+
out_data = reshape(single_out, 1, :)
2862
end
2963

30-
return out
64+
if vec_like
65+
return vec(out_data), st
66+
else
67+
return out_data, st
68+
end
3169
end
3270

33-
end #module
71+
end # module

src/ReservoirComputing.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,35 +17,33 @@ using WeightInitializers: DeviceAgnostic, PartialFunction, Utils
1717
@reexport using WeightInitializers
1818
@reexport using LuxCore: setup, apply
1919

20-
2120
const BoolType = Union{StaticBool,Bool,Val{true},Val{false}}
2221
const InputType = Tuple{<:AbstractArray,Tuple{<:AbstractArray}}
2322
const IntegerType = Union{Integer,StaticInteger}
2423

25-
@compat(public, (create_states))
24+
#@compat(public, (create_states)) #do I need to add intialstates/parameters in compat?
2625

2726
#layers
2827
include("layers/basic.jl")
2928
include("layers/lux_layers.jl")
3029
include("layers/esn_cell.jl")
30+
include("layers/svmreadout.jl")
3131
#general
3232
include("states.jl")
3333
include("predict.jl")
3434
include("train.jl")
35-
#esn
35+
#initializers
3636
include("inits/inits_components.jl")
3737
include("inits/esn_inits.jl")
38+
#full models
3839
include("models/esn.jl")
3940
include("models/deepesn.jl")
4041
include("models/hybridesn.jl")
4142
#extensions
4243
include("extensions/reca.jl")
4344

44-
45-
46-
4745
export ESNCell, StatefulLayer, Readout, ReservoirChain, Collect, collectstates, train!, predict
48-
46+
export SVMReadout
4947
export Pad, Extend, NLAT1, NLAT2, NLAT3, PartialSquare, ExtendedSquare
5048
export StandardRidge
5149
export chebyshev_mapping, informed_init, logistic_mapping, minimal_init,

src/layers/basic.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,9 @@ function collectstates(rc::AbstractLuxLayer, data::AbstractArray, ps, st::NamedT
195195
inp_tmp = inp
196196
state_vec = nothing
197197
for (name, layer) in pairs(rc.layers)
198+
if layer isa AbstractReservoirTrainableLayer
199+
break
200+
end
198201
inp_tmp, st_i = layer(inp_tmp, ps[name], newst[name])
199202
newst = merge(newst, (; name => st_i))
200203
if layer isa AbstractReservoirCollectionLayer

src/layers/svmreadout.jl

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
@concrete struct SVMReadout <: AbstractReservoirTrainableLayer
2+
in_dims <: IntegerType
3+
out_dims <: IntegerType
4+
include_collect <: StaticBool
5+
end
6+
7+
function Base.show(io::IO, ro::SVMReadout)
8+
print(io, "SVMReadout($(ro.in_dims) => $(ro.out_dims)")
9+
ic = known(getproperty(ro, Val(:include_collect)))
10+
ic === true && print(io, ", include_collect=true")
11+
return print(io, ")")
12+
end
13+
14+
SVMReadout(mapping::Pair{<:IntegerType,<:IntegerType}; kwargs...) =
15+
SVMReadout(first(mapping), last(mapping); kwargs...)
16+
17+
SVMReadout(in_dims::IntegerType, out_dims::IntegerType;
18+
include_collect::BoolType=True()) =
19+
SVMReadout(in_dims, out_dims, static(include_collect))
20+
21+
initialparameters(::AbstractRNG, ::SVMReadout) = NamedTuple()
22+
parameterlength(::SVMReadout) = 0
23+
statelength(::SVMReadout) = 0
24+
outputsize(ro::SVMReadout, _, ::AbstractRNG) = (ro.out_dims,)
25+
26+
# NOTE: forward for SVMReadout will be defined in the LIBSVM extension,
27+
# because it calls LIBSVM.predict.
28+
29+
_svmreadout_include_collect(ro::SVMReadout) = begin
30+
ic = known(getproperty(ro, Val(:include_collect)))
31+
ic === nothing ? false : ic
32+
end
33+
34+
function wrap_functions_in_chain_call(ro::SVMReadout)
35+
return _svmreadout_include_collect(ro) ? (Collect(), ro) : ro
36+
end
37+
38+
39+
_quote_keys(t) = Expr(:tuple, (QuoteNode(s) for s in t)...)
40+
41+
function _setmodels_rt(p::NamedTuple{K}, M) where {K}
42+
keys = K
43+
Kq = _quote_keys(keys)
44+
idx = findfirst(==(Symbol(:models)), keys)
45+
46+
terms = Any[]
47+
for i in 1:length(keys)
48+
push!(terms, (idx === i) ? :(M) : :(getfield(p, $i)))
49+
end
50+
51+
if idx === nothing
52+
newK = _quote_keys((keys..., :models))
53+
return :(NamedTuple{$newK}(($(terms...), M)))
54+
else
55+
return :(NamedTuple{$Kq}(($(terms...),)))
56+
end
57+
end
58+
59+
@generated function _addsvm(layers::NamedTuple{K}, ps::NamedTuple{K}, M) where {K}
60+
if length(K) == 0
61+
return :(NamedTuple())
62+
end
63+
tailK = Base.tail(K)
64+
Kq = _quote_keys(K)
65+
tailKq = _quote_keys(tailK)
66+
67+
head_val = :((getfield(layers, 1) isa SVMReadout)
68+
? _setmodels_rt(getfield(ps, 1), M)
69+
: getfield(ps, 1))
70+
71+
tail_call = :(_addsvm(NamedTuple{$tailKq}(Base.tail(layers)),
72+
NamedTuple{$tailKq}(Base.tail(ps)), M))
73+
74+
return :(NamedTuple{$Kq}(($head_val, Base.values($tail_call)...)))
75+
end
76+
77+
function addreadout!(rc::ReservoirChain, models, ps::NamedTuple, st::NamedTuple)
78+
@assert propertynames(rc.layers) == propertynames(ps)
79+
new_ps = _addsvm(rc.layers, ps, models)
80+
return new_ps, st
81+
end

src/predict.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ function predict(rc::AbstractLuxLayer, data::AbstractMatrix, ps, st)
6262
Y = similar(y1, size(y1, 1), T)
6363
Y[:, 1] .= y1
6464

65-
@inbounds @views for t in 2:T
65+
for t in 2:T
6666
yt, st = apply(rc, data[:, t], ps, st)
6767
Y[:, t] .= yt
6868
end

src/train.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ function train(sr::StandardRidge, states::AbstractArray, target_data::AbstractAr
9494
return output_layer
9595
end
9696

97-
_quote_keys(t) = Expr(:tuple, (QuoteNode(s) for s in t)...)
97+
#_quote_keys(t) = Expr(:tuple, (QuoteNode(s) for s in t)...)
9898

9999
@generated function _setweight_rt(p::NamedTuple{K}, W) where {K}
100100
keys = K

0 commit comments

Comments
 (0)