Skip to content

Commit c4d214d

Browse files
fix: lower bound for JET
1 parent 3617293 commit c4d214d

File tree

3 files changed

+36
-15
lines changed

3 files changed

+36
-15
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ CellularAutomata = "0.0.6"
3434
Compat = "4.16.0"
3535
ConcreteStructs = "0.2.3"
3636
DifferentialEquations = "7.16.1"
37-
JET = "0.10.9"
37+
JET = "0.9, 0.10.10"
3838
LIBSVM = "0.8"
3939
LinearAlgebra = "1.10"
4040
LuxCore = "1.3.0"

src/layers/lux_layers.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,9 @@ end
8282
(c::ReservoirChain)(x, ps, st::NamedTuple) = applychain(c.layers, x, ps, st)
8383

8484
@generated function applychain(
85-
layers::NamedTuple{fields}, x, ps, st::NamedTuple{fields}) where {fields}
85+
layers::NamedTuple{fields}, x, ps, st::NamedTuple{fields}
86+
) where {fields}
87+
@assert isa(fields, NTuple{<:Any, Symbol})
8688
N = length(fields)
8789
x_symbols = vcat([:x], [gensym() for _ in 1:N])
8890
st_symbols = [gensym() for _ in 1:N]

src/models/esn_generics.jl

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,42 @@ _wrap_layers(xs::Tuple) = map(_wrap_layer, xs)
1212
return inp, tuple(new_st_parts...)
1313
end
1414

15-
function _asvec(comp, n_layers::Integer)
16-
if comp === ()
17-
return ntuple(_ -> nothing, n_layers)
18-
elseif comp isa Tuple || comp isa AbstractVector
19-
len = length(comp)
20-
if len == n_layers
21-
return Tuple(comp)
22-
elseif len == 1
23-
return ntuple(_ -> comp[1], n_layers)
24-
else
25-
error("Expected length $n_layers or 1, got $len")
26-
end
15+
@inline function _fillvec(x, n::Integer)
16+
v = Vector{typeof(x)}(undef, n)
17+
@inbounds @simd for i in 1:n
18+
v[i] = x
19+
end
20+
return v
21+
end
22+
23+
@inline _asvec(::Tuple{}, n::Integer) = _fillvec(nothing, n)
24+
25+
@inline function _asvec(comp::Tuple, n::Integer)
26+
len = length(comp)
27+
if len == n
28+
return collect(comp)
29+
elseif len == 1
30+
return _fillvec(comp[1], n)
31+
else
32+
error("Expected length $n or 1, got $len")
33+
end
34+
end
35+
36+
@inline function _asvec(comp::AbstractVector, n::Integer)
37+
len = length(comp)
38+
if len == n
39+
return collect(comp)
40+
elseif len == 1
41+
return _fillvec(comp[1], n)
2742
else
28-
return ntuple(_ -> comp, n_layers)
43+
error("Expected length $n or 1, got $len")
2944
end
3045
end
3146

47+
@inline _asvec(::Nothing, n::Integer) = _fillvec(nothing, n)
48+
49+
@inline _asvec(comp, n::Integer) = _fillvec(comp, n)
50+
3251
@inline _asvec(x) = (ndims(x) == 2 ? vec(x) : x)
3352

3453
function _coerce_layer_mods(x)

0 commit comments

Comments
 (0)