Skip to content

Commit 578a561

Browse files
author
Sathvik Bhagavan
committed
refactor: use RealInputArray and RealOutputArray
1 parent 02041f1 commit 578a561

File tree

1 file changed

+6
-25
lines changed

1 file changed

+6
-25
lines changed

src/ModelingToolkitNeuralNets.jl

Lines changed: 6 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
module ModelingToolkitNeuralNets
22

3-
using ModelingToolkit: @parameters, @named, ODESystem, t_nounits, @connector, @variables,
4-
Equation
5-
using ModelingToolkitStandardLibrary.Blocks: RealInput, RealOutput
3+
using ModelingToolkit: @parameters, @named, ODESystem, t_nounits
4+
using ModelingToolkitStandardLibrary.Blocks: RealInputArray, RealOutputArray
65
using Symbolics: Symbolics, @register_array_symbolic, @wrapped
76
using LuxCore: stateless_apply
87
using Lux: Lux
@@ -13,24 +12,6 @@ export NeuralNetworkBlock, multi_layer_feed_forward
1312

1413
include("utils.jl")
1514

16-
@connector function RealInput2(; name, nin = 1, u_start = zeros(nin))
17-
@variables u(t_nounits)[1:nin]=u_start [
18-
input = true,
19-
description = "Inner variable in RealInput $name"
20-
]
21-
u = collect(u)
22-
ODESystem(Equation[], t_nounits, [u...], []; name = name)
23-
end
24-
25-
@connector function RealOutput2(; name, nout = 1, u_start = zeros(nout))
26-
@variables u(t_nounits)[1:nout]=u_start [
27-
output = true,
28-
description = "Inner variable in RealOutput $name"
29-
]
30-
u = collect(u)
31-
ODESystem(Equation[], t_nounits, [u...], []; name = name)
32-
end
33-
3415
"""
3516
NeuralNetworkBlock(n_input = 1, n_output = 1;
3617
chain = multi_layer_feed_forward(n_input, n_output),
@@ -49,12 +30,12 @@ function NeuralNetworkBlock(n_input = 1,
4930
ca = ComponentArray{eltype}(init_params)
5031

5132
@parameters p[1:length(ca)] = Vector(ca)
52-
@parameters T::typeof(typeof(p))=typeof(p) [tunable = false]
33+
@parameters T::typeof(typeof(ca))=typeof(ca) [tunable = false]
5334

54-
@named input = RealInput2(nin = n_input)
55-
@named output = RealOutput2(nout = n_output)
35+
@named input = RealInputArray(nin = n_input)
36+
@named output = RealOutputArray(nout = n_output)
5637

57-
out = stateless_apply(chain, input.u, lazyconvert(typeof(ca), p))
38+
out = stateless_apply(chain, input.u, lazyconvert(T, p))
5839

5940
eqs = [output.u ~ out]
6041

0 commit comments

Comments
 (0)