Skip to content

Commit 39d16da

Browse files
committed
feat: add SymbolicNeuralNetwork as a callable parameter interface
1 parent 774aba9 commit 39d16da

File tree

2 files changed

+65
-1
lines changed

2 files changed

+65
-1
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "1.6.1"
55

66
[deps]
77
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
8+
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
89
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
910
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
1011
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
@@ -17,6 +18,7 @@ Aqua = "0.8"
1718
ComponentArrays = "0.15.11"
1819
DifferentiationInterface = "0.6"
1920
ForwardDiff = "0.10.36"
21+
IntervalSets = "0.7.10"
2022
JET = "0.8, 0.9"
2123
Lux = "1"
2224
LuxCore = "1"

src/ModelingToolkitNeuralNets.jl

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
module ModelingToolkitNeuralNets
22

33
using ModelingToolkit: @parameters, @named, ODESystem, t_nounits
4+
using IntervalSets: (..)
45
using ModelingToolkitStandardLibrary.Blocks: RealInputArray, RealOutputArray
56
using Symbolics: Symbolics, @register_array_symbolic, @wrapped
67
using LuxCore: stateless_apply
78
using Lux: Lux
89
using Random: Xoshiro
910
using ComponentArrays: ComponentArray
1011

11-
export NeuralNetworkBlock, multi_layer_feed_forward
12+
export NeuralNetworkBlock, SymbolicNeuralNetwork, multi_layer_feed_forward, get_network
1213

1314
include("utils.jl")
1415

@@ -56,4 +57,65 @@ function lazyconvert(T, x::Symbolics.Arr)
5657
Symbolics.array_term(convert, T, x, size = size(x))
5758
end
5859

60+
"""
61+
SymbolicNeuralNetwork(; n_input = 1, n_output = 1,
62+
chain = multi_layer_feed_forward(n_input, n_output),
63+
rng = Xoshiro(0),
64+
init_params = Lux.initialparameters(rng, chain),
65+
eltype = Float64)
66+
67+
Create symbolic parameter for a neural network and one for its parameters.
68+
Example:
69+
```
70+
chain = multi_layer_feed_forward(2, 2)
71+
NN, p = SymbolicNeuralNetwork(; chain, n_input=2, n_output=2, rng = StableRNG(42))
72+
```
73+
74+
The NN and p are symbolic parameters that can be used lates as part of a system.
75+
To get the predictions of the neural network, use
76+
```
77+
pred ~ NN(input, p)
78+
```
79+
where `pred` and `input` are a symbolic vector variable with the lengths `n_output` and `n_input`.
80+
81+
To use this outside of an equation, you can get the default values for the symbols and make a similar call
82+
```
83+
defaults(sys)[sys.NN](input, nn_p)
84+
```
85+
86+
where `sys` is a system (e.g. `ODESystem`) that contains `NN`, `input` is a vector of `n_input` length and
87+
`nn_p` is a vector representing parameter values for the neural network.
88+
89+
To get the underlying Lux model you can use `get_network(defaults(sys)[sys.NN])` or
90+
"""
91+
function SymbolicNeuralNetwork(; n_input = 1, n_output = 1,
92+
chain = multi_layer_feed_forward(n_input, n_output),
93+
rng = Xoshiro(0),
94+
init_params = Lux.initialparameters(rng, chain),
95+
eltype = Float64)
96+
ca = ComponentArray{eltype}(init_params)
97+
wrapper = StatelessApplyWrapper(chain, ca)
98+
99+
@parameters p[1:length(ca)] = Vector(ca)
100+
@parameters (NN::typeof(wrapper))(..)[1:n_output] = wrapper
101+
102+
return NN, p
103+
end
104+
105+
struct StatelessApplyWrapper{NN}
106+
lux_model::NN
107+
T::DataType
108+
end
109+
110+
function (wrapper::StatelessApplyWrapper)(input::AbstractArray, nn_p::AbstractVector)
111+
stateless_apply(get_network(wrapper), input, convert(wrapper.T, nn_p))
112+
end
113+
114+
function Base.show(io::IO, m::MIME"text/plain", wrapper::StatelessApplyWrapper)
115+
printstyled(io, "LuxCore.stateless_apply wrapper for:\n", color=:gray)
116+
show(io, m, get_network(wrapper))
117+
end
118+
119+
get_network(wrapper::StatelessApplyWrapper) = wrapper.lux_model
120+
59121
end

0 commit comments

Comments
 (0)