6262 chain = multi_layer_feed_forward(n_input, n_output),
6363 rng = Xoshiro(0),
6464 init_params = Lux.initialparameters(rng, chain),
65+ nn_name = :NN,
66+ nn_p_name = :p,
6567 eltype = Float64)
6668
6769Create symbolic parameter for a neural network and one for its parameters.
@@ -73,6 +75,7 @@ NN, p = SymbolicNeuralNetwork(; chain, n_input=2, n_output=2, rng = StableRNG(42
7375```
7476
7577The NN and p are symbolic parameters that can be used later as part of a system.
78+ To change the name of the symbolic variables, use `nn_name` and `nn_p_name`.
7679To get the predictions of the neural network, use
7780
7881```
@@ -96,14 +99,16 @@ function SymbolicNeuralNetwork(; n_input = 1, n_output = 1,
9699 chain = multi_layer_feed_forward (n_input, n_output),
97100 rng = Xoshiro (0 ),
98101 init_params = Lux. initialparameters (rng, chain),
102+ nn_name = :NN ,
103+ nn_p_name = :p ,
99104 eltype = Float64)
100105 ca = ComponentArray {eltype} (init_params)
101106 wrapper = StatelessApplyWrapper (chain, typeof (ca))
102107
103- @parameters p [1 : length (ca)] = Vector (ca)
104- @parameters (NN :: typeof (wrapper))(.. )[1 : n_output] = wrapper
108+ p = @parameters $ (nn_p_name) [1 : length (ca)] = Vector (ca)
109+ NN = @parameters ($ (nn_name) :: typeof (wrapper))(.. )[1 : n_output] = wrapper
105110
106- return NN, p
111+ return only (NN), only (p)
107112end
108113
109114struct StatelessApplyWrapper{NN}
0 commit comments