Skip to content

Commit cc3b075

Browse files
committed
refactor: avoid hardcoding symbolic var names
1 parent 3979b21 commit cc3b075

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

src/ModelingToolkitNeuralNets.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ end
6262
chain = multi_layer_feed_forward(n_input, n_output),
6363
rng = Xoshiro(0),
6464
init_params = Lux.initialparameters(rng, chain),
65+
nn_name = :NN,
66+
nn_p_name = :p,
6567
eltype = Float64)
6668
6769
Create symbolic parameter for a neural network and one for its parameters.
@@ -73,6 +75,7 @@ NN, p = SymbolicNeuralNetwork(; chain, n_input=2, n_output=2, rng = StableRNG(42
7375
```
7476
7577
The NN and p are symbolic parameters that can be used later as part of a system.
78+
To change the name of the symbolic variables, use `nn_name` and `nn_p_name`.
7679
To get the predictions of the neural network, use
7780
7881
```
@@ -96,14 +99,16 @@ function SymbolicNeuralNetwork(; n_input = 1, n_output = 1,
9699
chain = multi_layer_feed_forward(n_input, n_output),
97100
rng = Xoshiro(0),
98101
init_params = Lux.initialparameters(rng, chain),
102+
nn_name = :NN,
103+
nn_p_name = :p,
99104
eltype = Float64)
100105
ca = ComponentArray{eltype}(init_params)
101106
wrapper = StatelessApplyWrapper(chain, typeof(ca))
102107

103-
@parameters p[1:length(ca)] = Vector(ca)
104-
@parameters (NN::typeof(wrapper))(..)[1:n_output] = wrapper
108+
p = @parameters $(nn_p_name)[1:length(ca)] = Vector(ca)
109+
NN = @parameters ($(nn_name)::typeof(wrapper))(..)[1:n_output] = wrapper
105110

106-
return NN, p
111+
return only(NN), only(p)
107112
end
108113

109114
struct StatelessApplyWrapper{NN}

0 commit comments

Comments
 (0)