62
62
chain = multi_layer_feed_forward(n_input, n_output),
63
63
rng = Xoshiro(0),
64
64
init_params = Lux.initialparameters(rng, chain),
65
+ nn_name = :NN,
66
+ nn_p_name = :p,
65
67
eltype = Float64)
66
68
67
69
Create symbolic parameter for a neural network and one for its parameters.
@@ -73,6 +75,7 @@ NN, p = SymbolicNeuralNetwork(; chain, n_input=2, n_output=2, rng = StableRNG(42
73
75
```
74
76
75
77
The NN and p are symbolic parameters that can be used later as part of a system.
78
+ To change the name of the symbolic variables, use `nn_name` and `nn_p_name`.
76
79
To get the predictions of the neural network, use
77
80
78
81
```
@@ -96,14 +99,16 @@ function SymbolicNeuralNetwork(; n_input = 1, n_output = 1,
96
99
chain = multi_layer_feed_forward (n_input, n_output),
97
100
rng = Xoshiro (0 ),
98
101
init_params = Lux. initialparameters (rng, chain),
102
+ nn_name = :NN ,
103
+ nn_p_name = :p ,
99
104
eltype = Float64)
100
105
ca = ComponentArray {eltype} (init_params)
101
106
wrapper = StatelessApplyWrapper (chain, typeof (ca))
102
107
103
- @parameters p [1 : length (ca)] = Vector (ca)
104
- @parameters (NN :: typeof (wrapper))(.. )[1 : n_output] = wrapper
108
+ p = @parameters $ (nn_p_name) [1 : length (ca)] = Vector (ca)
109
+ NN = @parameters ($ (nn_name) :: typeof (wrapper))(.. )[1 : n_output] = wrapper
105
110
106
- return NN, p
111
+ return only (NN), only (p)
107
112
end
108
113
109
114
struct StatelessApplyWrapper{NN}
0 commit comments