|
1 | 1 | module ModelingToolkitNeuralNets
|
2 | 2 |
|
3 | 3 | using ModelingToolkit: @parameters, @named, ODESystem, t_nounits
|
| 4 | +using IntervalSets: (..) |
4 | 5 | using ModelingToolkitStandardLibrary.Blocks: RealInputArray, RealOutputArray
|
5 | 6 | using Symbolics: Symbolics, @register_array_symbolic, @wrapped
|
6 | 7 | using LuxCore: stateless_apply
|
7 | 8 | using Lux: Lux
|
8 | 9 | using Random: Xoshiro
|
9 | 10 | using ComponentArrays: ComponentArray
|
10 | 11 |
|
11 |
| -export NeuralNetworkBlock, multi_layer_feed_forward |
| 12 | +export NeuralNetworkBlock, SymbolicNeuralNetwork, multi_layer_feed_forward, get_network |
12 | 13 |
|
13 | 14 | include("utils.jl")
|
14 | 15 |
|
@@ -56,4 +57,69 @@ function lazyconvert(T, x::Symbolics.Arr)
|
56 | 57 | Symbolics.array_term(convert, T, x, size = size(x))
|
57 | 58 | end
|
58 | 59 |
|
| 60 | +""" |
| 61 | + SymbolicNeuralNetwork(; n_input = 1, n_output = 1, |
| 62 | + chain = multi_layer_feed_forward(n_input, n_output), |
| 63 | + rng = Xoshiro(0), |
| 64 | + init_params = Lux.initialparameters(rng, chain), |
| 65 | + eltype = Float64) |
| 66 | +
|
| 67 | +Create symbolic parameter for a neural network and one for its parameters. |
| 68 | +Example: |
| 69 | +
|
| 70 | +``` |
| 71 | +chain = multi_layer_feed_forward(2, 2) |
| 72 | +NN, p = SymbolicNeuralNetwork(; chain, n_input=2, n_output=2, rng = StableRNG(42)) |
| 73 | +``` |
| 74 | +
|
| 75 | +The NN and p are symbolic parameters that can be used later as part of a system. |
| 76 | +To get the predictions of the neural network, use |
| 77 | +
|
| 78 | +``` |
| 79 | +pred ~ NN(input, p) |
| 80 | +``` |
| 81 | +
|
| 82 | +where `pred` and `input` are a symbolic vector variable with the lengths `n_output` and `n_input`. |
| 83 | +
|
| 84 | +To use this outside of an equation, you can get the default values for the symbols and make a similar call |
| 85 | +
|
| 86 | +``` |
| 87 | +defaults(sys)[sys.NN](input, nn_p) |
| 88 | +``` |
| 89 | +
|
| 90 | +where `sys` is a system (e.g. `ODESystem`) that contains `NN`, `input` is a vector of `n_input` length and |
| 91 | +`nn_p` is a vector representing parameter values for the neural network. |
| 92 | +
|
| 93 | +To get the underlying Lux model you can use `get_network(defaults(sys)[sys.NN])` or |
| 94 | +""" |
| 95 | +function SymbolicNeuralNetwork(; n_input = 1, n_output = 1, |
| 96 | + chain = multi_layer_feed_forward(n_input, n_output), |
| 97 | + rng = Xoshiro(0), |
| 98 | + init_params = Lux.initialparameters(rng, chain), |
| 99 | + eltype = Float64) |
| 100 | + ca = ComponentArray{eltype}(init_params) |
| 101 | + wrapper = StatelessApplyWrapper(chain, typeof(ca)) |
| 102 | + |
| 103 | + @parameters p[1:length(ca)] = Vector(ca) |
| 104 | + @parameters (NN::typeof(wrapper))(..)[1:n_output] = wrapper |
| 105 | + |
| 106 | + return NN, p |
| 107 | +end |
| 108 | + |
| 109 | +struct StatelessApplyWrapper{NN} |
| 110 | + lux_model::NN |
| 111 | + T::DataType |
| 112 | +end |
| 113 | + |
| 114 | +function (wrapper::StatelessApplyWrapper)(input::AbstractArray, nn_p::AbstractVector) |
| 115 | + stateless_apply(get_network(wrapper), input, convert(wrapper.T, nn_p)) |
| 116 | +end |
| 117 | + |
| 118 | +function Base.show(io::IO, m::MIME"text/plain", wrapper::StatelessApplyWrapper) |
| 119 | + printstyled(io, "LuxCore.stateless_apply wrapper for:\n", color = :gray) |
| 120 | + show(io, m, get_network(wrapper)) |
| 121 | +end |
| 122 | + |
| 123 | +get_network(wrapper::StatelessApplyWrapper) = wrapper.lux_model |
| 124 | + |
59 | 125 | end
|
0 commit comments