Skip to content

Commit 4b37aba

Browse files
Merge pull request #24 from SciML/ChrisRackauckas-patch-1
Missing name in NueralNetworkBlock
2 parents 2d6db67 + 7ad050c commit 4b37aba

File tree

3 files changed

+8
-6
lines changed

3 files changed

+8
-6
lines changed

docs/src/friction.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ chain = Lux.Chain(
9696
Lux.Dense(10 => 10, Lux.mish, use_bias = false),
9797
Lux.Dense(10 => 1, use_bias = false)
9898
)
99-
nn = NeuralNetworkBlock(1, 1; chain = chain, rng = StableRNG(1111))
99+
@named nn = NeuralNetworkBlock(1, 1; chain = chain, rng = StableRNG(1111))
100100
101101
eqs = [connect(model.nn_in, nn.output)
102102
connect(model.nn_out, nn.input)]

src/ModelingToolkitNeuralNets.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ include("utils.jl")
1717
chain = multi_layer_feed_forward(n_input, n_output),
1818
rng = Xoshiro(0),
1919
init_params = Lux.initialparameters(rng, chain),
20-
eltype = Float64)
20+
eltype = Float64,
21+
name)
2122
2223
Create an `ODESystem` with a neural network inside.
2324
"""
@@ -26,7 +27,8 @@ function NeuralNetworkBlock(n_input = 1,
2627
chain = multi_layer_feed_forward(n_input, n_output),
2728
rng = Xoshiro(0),
2829
init_params = Lux.initialparameters(rng, chain),
29-
eltype = Float64)
30+
eltype = Float64,
31+
name)
3032
ca = ComponentArray{eltype}(init_params)
3133

3234
@parameters p[1:length(ca)] = Vector(ca)
@@ -39,8 +41,8 @@ function NeuralNetworkBlock(n_input = 1,
3941

4042
eqs = [output.u ~ out]
4143

42-
@named ude_comp = ODESystem(
43-
eqs, t_nounits, [], [p, T], systems = [input, output])
44+
ude_comp = ODESystem(
45+
eqs, t_nounits, [], [p, T]; systems = [input, output], name)
4446
return ude_comp
4547
end
4648

test/lotka_volterra.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ end
4444
model = lotka_ude()
4545

4646
chain = multi_layer_feed_forward(2, 2)
47-
nn = NeuralNetworkBlock(2, 2; chain, rng = StableRNG(42))
47+
@named nn = NeuralNetworkBlock(2, 2; chain, rng = StableRNG(42))
4848

4949
eqs = [connect(model.nn_in, nn.output)
5050
connect(model.nn_out, nn.input)]

0 commit comments

Comments
 (0)