
Commit 27c1213

tests: add JET, small fixes
1 parent 57cbd5a commit 27c1213

15 files changed, +221 -370 lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
@@ -36,6 +36,7 @@ CellularAutomata = "0.0.6"
 Compat = "4.16.0"
 ConcreteStructs = "0.2.3"
 DifferentialEquations = "7.16.1"
+JET = "0.9.20"
 LIBSVM = "0.8"
 LinearAlgebra = "1.10"
 LuxCore = "1.3.0"
@@ -54,6 +55,7 @@ julia = "1.10"
 [extras]
 Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
 DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
+JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
 LIBSVM = "b1bec4e5-fd48-53fe-b0cb-9723c09d164b"
 MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692"
 SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
@@ -62,4 +64,4 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [targets]
-test = ["Aqua", "Test", "SafeTestsets", "DifferentialEquations", "MLJLinearModels", "LIBSVM", "Statistics", "SparseArrays", "CellularAutomata"]
+test = ["Aqua", "Test", "SafeTestsets", "DifferentialEquations", "MLJLinearModels", "LIBSVM", "Statistics", "SparseArrays", "CellularAutomata", "JET"]

ext/RCLIBSVMExt.jl

Lines changed: 3 additions & 3 deletions
@@ -52,9 +52,9 @@ function (svmro::SVMReadout)(inp::AbstractArray, ps, st::NamedTuple)

     if models isa AbstractVector
         out_data = Array{float(eltype(reshaped_inp))}(undef, svmro.out_dims, num_imp)
-        @inbounds for i in 1:svmro.out_dims
-            single_out = LIBSVM.predict(models[i], reshaped_inp)
-            out_data[i, :] = single_out
+        for (idx, model) in enumerate(models)
+            single_out = LIBSVM.predict(models[idx], reshaped_inp)
+            out_data[idx, :] = single_out
         end
     else
         single_out = LIBSVM.predict(models, reshaped_inp)

src/ReservoirComputing.jl

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ using Adapt: adapt
 using ArrayInterface: ArrayInterface
 using Compat: @compat
 using ConcreteStructs: @concrete
-using LinearAlgebra: eigvals, mul!, I, qr, Diagonal
+using LinearAlgebra: eigvals, mul!, I, qr, Diagonal, diag
 using LuxCore: AbstractLuxLayer, AbstractLuxContainerLayer, AbstractLuxWrapperLayer,
     setup, apply, replicate
 import LuxCore: initialparameters, initialstates, statelength, outputsize

src/inits/esn_inits.jl

Lines changed: 164 additions & 162 deletions
Large diffs are not rendered by default.

src/layers/basic.jl

Lines changed: 21 additions & 14 deletions
@@ -52,33 +52,27 @@ before this layer (logically inserting a [`Collect()`](@ref) right before it).
     activation
     in_dims <: IntegerType
     out_dims <: IntegerType
+    init_weight
+    init_bias
     use_bias <: StaticBool
     include_collect <: StaticBool
 end

-function Base.show(io::IO, ro::Readout)
-    print(io, "Readout($(ro.in_dims) => $(ro.out_dims)")
-    (ro.activation == identity) || print(io, ", $(ro.activation)")
-    has_bias(ro) || print(io, ", use_bias=false")
-    ic = known(getproperty(ro, Val(:include_collect)))
-    ic === true && print(io, ", include_collect=true")
-    return print(io, ")")
-end
-
 function Readout(mapping::Pair{<:IntegerType,<:IntegerType}, activation=identity; kwargs...)
     return Readout(first(mapping), last(mapping), activation; kwargs...)
 end

 function Readout(in_dims::IntegerType, out_dims::IntegerType, activation=identity;
-        include_collect::BoolType=True(), use_bias::BoolType=False())
-    return Readout(activation, in_dims, out_dims, static(use_bias), static(include_collect))
+        init_weight=rand32, init_bias=rand32, include_collect::BoolType=True(),
+        use_bias::BoolType=False())
+    return Readout(activation, in_dims, out_dims, init_weight, init_bias, static(use_bias), static(include_collect))
 end

 function initialparameters(rng::AbstractRNG, ro::Readout)
-    weight = rand(rng, Float32, ro.out_dims, ro.in_dims)
+    weight = ro.init_weight(rng, ro.out_dims, ro.in_dims)

     if has_bias(ro)
-        return (; weight, bias=rand(rng, Float32, ro.out_dims))
+        return (; weight, bias=ro.init_bias(rng, Float32, ro.out_dims))
     else
         return (; weight)
     end
@@ -98,6 +92,15 @@ function (ro::Readout)(inp::AbstractArray, ps, st::NamedTuple)
     return output, st
 end

+function Base.show(io::IO, ro::Readout)
+    print(io, "Readout($(ro.in_dims) => $(ro.out_dims)")
+    (ro.activation == identity) || print(io, ", $(ro.activation)")
+    has_bias(ro) || print(io, ", use_bias=false")
+    ic = known(getproperty(ro, Val(:include_collect)))
+    ic === true && print(io, ", include_collect=true")
+    return print(io, ")")
+end
+
 @doc raw"""
     Collect()

@@ -188,7 +191,7 @@ that time step.
 - `st`: Updated model states.

 """
-function collectstates(rc::AbstractLuxLayer, data::AbstractArray, ps, st::NamedTuple)
+function collectstates(rc::AbstractLuxLayer, data::AbstractMatrix, ps, st::NamedTuple)
     newst = st
     collected = Any[]
     for inp in eachcol(data)
@@ -209,3 +212,7 @@ function collectstates(rc::AbstractLuxLayer, data::AbstractArray, ps, st::NamedT
     states = eltype(data).(reduce(hcat, collected))
     return states, newst
 end
+
+function collectstates(rc::AbstractLuxLayer, data::AbstractVector, ps, st::NamedTuple)
+    return collectstates(rc, reshape(data, :, 1), ps, st)
+end
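
The Readout constructor now exposes the weight and bias initializers. A minimal usage sketch, assuming randn32 (a WeightInitializers-style initializer, as used elsewhere in the package) is in scope; the sizes are arbitrary:

using LuxCore, Random

ro = Readout(300 => 3, tanh; init_weight=randn32)   # custom weight initializer, no bias
ps, st = LuxCore.setup(Random.Xoshiro(0), ro)       # ps.weight is a 3×300 Float32 matrix

The added collectstates method for AbstractVector simply treats a single sample as a one-column matrix before dispatching to the matrix method.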

src/layers/esn_cell.jl

Lines changed: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ Created by `initialstates(rng, esn)`:
     use_bias <: StaticBool
 end

-function ESNCell((in_dims, out_dims)::Pair{<:Int,<:Int}, activation=tanh;
+function ESNCell((in_dims, out_dims)::Pair{<:IntegerType,<:IntegerType}, activation=tanh;
        use_bias::BoolType=False(), init_bias=zeros32, init_reservoir=rand_sparse,
        init_input=weighted_init, init_state=randn32, leak_coefficient=1.0)
    return ESNCell(activation, in_dims, out_dims, init_bias, init_reservoir,
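
With the widened signature, any IntegerType pair dispatches to this constructor. A small sketch (sizes arbitrary, keyword taken from the signature above):

cell = ESNCell(3 => 300; leak_coefficient=0.9)    # Int pair, as before
cell32 = ESNCell(Int32(3) => Int32(300))          # other integer types now dispatch too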

src/layers/lux_layers.jl

Lines changed: 10 additions & 2 deletions
@@ -124,6 +124,10 @@ function named_tuple_layers(layers::Vararg{AbstractLuxLayer,N}) where {N}
     return NamedTuple{ntuple(i -> Symbol(:layer_, i), N)}(layers)
 end

+function index_namedtuple(nt::NamedTuple{fields}, idxs::AbstractArray) where {fields}
+    return NamedTuple{fields[idxs]}(values(nt)[idxs])
+end
+
 # from Lux extended_ops
 const KnownSymbolType{v} = Union{Val{v},StaticSymbol{v}}

@@ -132,6 +136,10 @@ function has_bias(l::AbstractLuxLayer)
     return ifelse(res === nothing, false, res)
 end

-function getproperty(x, ::KnownSymbolType{v}) where {v}
-    return v ∈ Base.propertynames(x) ? Base.getproperty(x, v) : nothing
+@generated function getproperty(x::X, ::KnownSymbolType{v}) where {X,v}
+    if hasfield(X, v)
+        return :(getfield(x, v))
+    else
+        return :(nothing)
+    end
 end
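
The new index_namedtuple helper selects entries by position while preserving the field names, and the @generated getproperty resolves the field lookup at compile time instead of scanning propertynames at runtime. A quick illustration of the helper:

nt = (layer_1 = 1, layer_2 = 2, layer_3 = 3)
index_namedtuple(nt, [1, 3])    # (layer_1 = 1, layer_3 = 3)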

src/states.jl

Lines changed: 9 additions & 3 deletions
@@ -1,7 +1,13 @@

-function _apply_tomatrix(res_states, x_old::AbstractMatrix)
-    results = res_states.(eachcol(x_old))
-    return hcat(results...)
+@inline function _apply_tomatrix(states_mod::F, states::AbstractMatrix) where {F<:Function}
+    cols = axes(states, 2)
+    states_1 = states_mod(states[:, first(cols)])
+    new_states = similar(states_1, length(states_1), length(cols))
+    new_states[:, 1] .= states_1
+    for (k, j) in enumerate(cols)
+        new_states[:, k] .= states_mod(@view states[:, j])
+    end
+    return new_states
 end

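The rewrite preallocates the result from the first transformed column instead of broadcasting over eachcol and splatting into hcat. A minimal sketch, using a hypothetical padding modifier purely for illustration:

pad_one(x) = vcat(x, one(eltype(x)))          # hypothetical per-column state modifier
states = rand(Float32, 3, 5)                  # 3 state variables × 5 time steps
padded = _apply_tomatrix(pad_one, states)     # 4×5: each column gets an extra 1.0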

test/esn/deepesn.jl

Lines changed: 0 additions & 25 deletions
This file was deleted.

test/esn/test_drivers.jl

Lines changed: 0 additions & 41 deletions
This file was deleted.
