Skip to content

Commit 774aba9

Browse files
committed
refactor: make the lux model a parameter in the NNBlock
needs JuliaSymbolics/Symbolics.jl#1508
1 parent 7466069 commit 774aba9

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

src/ModelingToolkitNeuralNets.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,17 @@ function NeuralNetworkBlock(; n_input = 1, n_output = 1,
3232

3333
@parameters p[1:length(ca)] = Vector(ca)
3434
@parameters T::typeof(typeof(ca))=typeof(ca) [tunable = false]
35+
@parameters lux_model::typeof(chain) = chain
3536

3637
@named input = RealInputArray(nin = n_input)
3738
@named output = RealOutputArray(nout = n_output)
3839

39-
out = stateless_apply(chain, input.u, lazyconvert(T, p))
40+
out = stateless_apply(lux_model, input.u, lazyconvert(T, p))
4041

4142
eqs = [output.u ~ out]
4243

4344
ude_comp = ODESystem(
44-
eqs, t_nounits, [], [p, T]; systems = [input, output], name)
45+
eqs, t_nounits, [], [lux_model, p, T]; systems = [input, output], name)
4546
return ude_comp
4647
end
4748

0 commit comments

Comments
 (0)