11module ModelingToolkitNeuralNets
22
33using ModelingToolkit: @parameters , @named , ODESystem, t_nounits
4+ using IntervalSets: var".."
45using ModelingToolkitStandardLibrary. Blocks: RealInputArray, RealOutputArray
56using Symbolics: Symbolics, @register_array_symbolic , @wrapped
67using LuxCore: stateless_apply
78using Lux: Lux
89using Random: Xoshiro
910using ComponentArrays: ComponentArray
1011
11- export NeuralNetworkBlock, multi_layer_feed_forward
12+ export NeuralNetworkBlock, SymbolicNeuralNetwork, multi_layer_feed_forward, get_network
1213
1314include (" utils.jl" )
1415
@@ -32,16 +33,17 @@ function NeuralNetworkBlock(; n_input = 1, n_output = 1,
3233
3334 @parameters p[1 : length (ca)] = Vector (ca)
3435 @parameters T:: typeof (typeof (ca))= typeof (ca) [tunable = false ]
36+ @parameters lux_model:: typeof (chain) = chain
3537
3638 @named input = RealInputArray (nin = n_input)
3739 @named output = RealOutputArray (nout = n_output)
3840
39- out = stateless_apply (chain , input. u, lazyconvert (T, p))
41+ out = stateless_apply (lux_model , input. u, lazyconvert (T, p))
4042
4143 eqs = [output. u ~ out]
4244
4345 ude_comp = ODESystem (
44- eqs, t_nounits, [], [p, T]; systems = [input, output], name)
46+ eqs, t_nounits, [], [lux_model, p, T]; systems = [input, output], name)
4547 return ude_comp
4648end
4749
@@ -55,4 +57,74 @@ function lazyconvert(T, x::Symbolics.Arr)
5557 Symbolics. array_term (convert, T, x, size = size (x))
5658end
5759
60+ """
61+ SymbolicNeuralNetwork(; n_input = 1, n_output = 1,
62+ chain = multi_layer_feed_forward(n_input, n_output),
63+ rng = Xoshiro(0),
64+ init_params = Lux.initialparameters(rng, chain),
65+ nn_name = :NN,
66+ nn_p_name = :p,
67+ eltype = Float64)
68+
69+ Create symbolic parameter for a neural network and one for its parameters.
70+ Example:
71+
72+ ```
73+ chain = multi_layer_feed_forward(2, 2)
74+ NN, p = SymbolicNeuralNetwork(; chain, n_input=2, n_output=2, rng = StableRNG(42))
75+ ```
76+
77+ The NN and p are symbolic parameters that can be used later as part of a system.
78+ To change the name of the symbolic variables, use `nn_name` and `nn_p_name`.
79+ To get the predictions of the neural network, use
80+
81+ ```
82+ pred ~ NN(input, p)
83+ ```
84+
85+ where `pred` and `input` are a symbolic vector variable with the lengths `n_output` and `n_input`.
86+
87+ To use this outside of an equation, you can get the default values for the symbols and make a similar call
88+
89+ ```
90+ defaults(sys)[sys.NN](input, nn_p)
91+ ```
92+
93+ where `sys` is a system (e.g. `ODESystem`) that contains `NN`, `input` is a vector of `n_input` length and
94+ `nn_p` is a vector representing parameter values for the neural network.
95+
96+ To get the underlying Lux model you can use `get_network(defaults(sys)[sys.NN])` or
97+ """
98+ function SymbolicNeuralNetwork (; n_input = 1 , n_output = 1 ,
99+ chain = multi_layer_feed_forward (n_input, n_output),
100+ rng = Xoshiro (0 ),
101+ init_params = Lux. initialparameters (rng, chain),
102+ nn_name = :NN ,
103+ nn_p_name = :p ,
104+ eltype = Float64)
105+ ca = ComponentArray {eltype} (init_params)
106+ wrapper = StatelessApplyWrapper (chain, typeof (ca))
107+
108+ p = @parameters $ (nn_p_name)[1 : length (ca)] = Vector (ca)
109+ NN = @parameters ($ (nn_name):: typeof (wrapper))(.. )[1 : n_output] = wrapper
110+
111+ return only (NN), only (p)
112+ end
113+
114+ struct StatelessApplyWrapper{NN}
115+ lux_model:: NN
116+ T:: DataType
117+ end
118+
119+ function (wrapper:: StatelessApplyWrapper )(input:: AbstractArray , nn_p:: AbstractVector )
120+ stateless_apply (get_network (wrapper), input, convert (wrapper. T, nn_p))
121+ end
122+
123+ function Base. show (io:: IO , m:: MIME"text/plain" , wrapper:: StatelessApplyWrapper )
124+ printstyled (io, " LuxCore.stateless_apply wrapper for:\n " , color = :gray )
125+ show (io, m, get_network (wrapper))
126+ end
127+
128+ get_network (wrapper:: StatelessApplyWrapper ) = wrapper. lux_model
129+
58130end
0 commit comments