1
1
module ModelingToolkitNeuralNets
2
2
3
- using ModelingToolkit: @parameters , @named , ODESystem, t_nounits, @connector , @variables ,
4
- Equation
5
- using ModelingToolkitStandardLibrary. Blocks: RealInput, RealOutput
3
+ using ModelingToolkit: @parameters , @named , ODESystem, t_nounits
4
+ using ModelingToolkitStandardLibrary. Blocks: RealInputArray, RealOutputArray
6
5
using Symbolics: Symbolics, @register_array_symbolic , @wrapped
7
6
using LuxCore: stateless_apply
8
7
using Lux: Lux
@@ -13,24 +12,6 @@ export NeuralNetworkBlock, multi_layer_feed_forward
13
12
14
13
include (" utils.jl" )
15
14
16
- @connector function RealInput2 (; name, nin = 1 , u_start = zeros (nin))
17
- @variables u (t_nounits)[1 : nin]= u_start [
18
- input = true ,
19
- description = " Inner variable in RealInput $name "
20
- ]
21
- u = collect (u)
22
- ODESystem (Equation[], t_nounits, [u... ], []; name = name)
23
- end
24
-
25
- @connector function RealOutput2 (; name, nout = 1 , u_start = zeros (nout))
26
- @variables u (t_nounits)[1 : nout]= u_start [
27
- output = true ,
28
- description = " Inner variable in RealOutput $name "
29
- ]
30
- u = collect (u)
31
- ODESystem (Equation[], t_nounits, [u... ], []; name = name)
32
- end
33
-
34
15
"""
35
16
NeuralNetworkBlock(n_input = 1, n_output = 1;
36
17
chain = multi_layer_feed_forward(n_input, n_output),
@@ -49,12 +30,12 @@ function NeuralNetworkBlock(n_input = 1,
49
30
ca = ComponentArray {eltype} (init_params)
50
31
51
32
@parameters p[1 : length (ca)] = Vector (ca)
52
- @parameters T:: typeof (typeof (p ))= typeof (p ) [tunable = false ]
33
+ @parameters T:: typeof (typeof (ca ))= typeof (ca ) [tunable = false ]
53
34
54
- @named input = RealInput2 (nin = n_input)
55
- @named output = RealOutput2 (nout = n_output)
35
+ @named input = RealInputArray (nin = n_input)
36
+ @named output = RealOutputArray (nout = n_output)
56
37
57
- out = stateless_apply (chain, input. u, lazyconvert (typeof (ca) , p))
38
+ out = stateless_apply (chain, input. u, lazyconvert (T , p))
58
39
59
40
eqs = [output. u ~ out]
60
41
0 commit comments