Skip to content

Commit b182114

Browse files
committed
feat: add SymbolicNeuralNetwork as a callable parameter interface
1 parent 774aba9 commit b182114

File tree

2 files changed

+69
-1
lines changed

2 files changed

+69
-1
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "1.6.1"
55

66
[deps]
77
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
8+
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
89
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
910
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
1011
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
@@ -17,6 +18,7 @@ Aqua = "0.8"
1718
ComponentArrays = "0.15.11"
1819
DifferentiationInterface = "0.6"
1920
ForwardDiff = "0.10.36"
21+
IntervalSets = "0.7.10"
2022
JET = "0.8, 0.9"
2123
Lux = "1"
2224
LuxCore = "1"

src/ModelingToolkitNeuralNets.jl

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
module ModelingToolkitNeuralNets
22

33
using ModelingToolkit: @parameters, @named, ODESystem, t_nounits
4+
using IntervalSets: (..)
45
using ModelingToolkitStandardLibrary.Blocks: RealInputArray, RealOutputArray
56
using Symbolics: Symbolics, @register_array_symbolic, @wrapped
67
using LuxCore: stateless_apply
78
using Lux: Lux
89
using Random: Xoshiro
910
using ComponentArrays: ComponentArray
1011

11-
export NeuralNetworkBlock, multi_layer_feed_forward
12+
export NeuralNetworkBlock, SymbolicNeuralNetwork, multi_layer_feed_forward, get_network
1213

1314
include("utils.jl")
1415

@@ -56,4 +57,69 @@ function lazyconvert(T, x::Symbolics.Arr)
5657
Symbolics.array_term(convert, T, x, size = size(x))
5758
end
5859

60+
"""
61+
SymbolicNeuralNetwork(; n_input = 1, n_output = 1,
62+
chain = multi_layer_feed_forward(n_input, n_output),
63+
rng = Xoshiro(0),
64+
init_params = Lux.initialparameters(rng, chain),
65+
eltype = Float64)
66+
67+
Create symbolic parameter for a neural network and one for its parameters.
68+
Example:
69+
70+
```
71+
chain = multi_layer_feed_forward(2, 2)
72+
NN, p = SymbolicNeuralNetwork(; chain, n_input=2, n_output=2, rng = StableRNG(42))
73+
```
74+
75+
The NN and p are symbolic parameters that can be used later as part of a system.
76+
To get the predictions of the neural network, use
77+
78+
```
79+
pred ~ NN(input, p)
80+
```
81+
82+
where `pred` and `input` are a symbolic vector variable with the lengths `n_output` and `n_input`.
83+
84+
To use this outside of an equation, you can get the default values for the symbols and make a similar call
85+
86+
```
87+
defaults(sys)[sys.NN](input, nn_p)
88+
```
89+
90+
where `sys` is a system (e.g. `ODESystem`) that contains `NN`, `input` is a vector of `n_input` length and
91+
`nn_p` is a vector representing parameter values for the neural network.
92+
93+
To get the underlying Lux model you can use `get_network(defaults(sys)[sys.NN])` or
94+
"""
95+
function SymbolicNeuralNetwork(; n_input = 1, n_output = 1,
96+
chain = multi_layer_feed_forward(n_input, n_output),
97+
rng = Xoshiro(0),
98+
init_params = Lux.initialparameters(rng, chain),
99+
eltype = Float64)
100+
ca = ComponentArray{eltype}(init_params)
101+
wrapper = StatelessApplyWrapper(chain, typeof(ca))
102+
103+
@parameters p[1:length(ca)] = Vector(ca)
104+
@parameters (NN::typeof(wrapper))(..)[1:n_output] = wrapper
105+
106+
return NN, p
107+
end
108+
109+
struct StatelessApplyWrapper{NN}
110+
lux_model::NN
111+
T::DataType
112+
end
113+
114+
function (wrapper::StatelessApplyWrapper)(input::AbstractArray, nn_p::AbstractVector)
115+
stateless_apply(get_network(wrapper), input, convert(wrapper.T, nn_p))
116+
end
117+
118+
function Base.show(io::IO, m::MIME"text/plain", wrapper::StatelessApplyWrapper)
119+
printstyled(io, "LuxCore.stateless_apply wrapper for:\n", color = :gray)
120+
show(io, m, get_network(wrapper))
121+
end
122+
123+
get_network(wrapper::StatelessApplyWrapper) = wrapper.lux_model
124+
59125
end

0 commit comments

Comments
 (0)