Skip to content

Commit fc9f1e7

Browse files
committed
feat: add n_input and n_output as keyword arguments
This is needed for compatibility with `@mtkmodel`
1 parent 9d4aa15 commit fc9f1e7

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

src/ModelingToolkitNeuralNets.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ export NeuralNetworkBlock, multi_layer_feed_forward
1313
include("utils.jl")
1414

1515
"""
16-
NeuralNetworkBlock(n_input = 1, n_output = 1;
16+
NeuralNetworkBlock(; n_input = 1, n_output = 1,
1717
chain = multi_layer_feed_forward(n_input, n_output),
1818
rng = Xoshiro(0),
1919
init_params = Lux.initialparameters(rng, chain),
@@ -22,8 +22,7 @@ include("utils.jl")
2222
2323
Create an `ODESystem` with a neural network inside.
2424
"""
25-
function NeuralNetworkBlock(n_input = 1,
26-
n_output = 1;
25+
function NeuralNetworkBlock(; n_input = 1, n_output = 1,
2726
chain = multi_layer_feed_forward(n_input, n_output),
2827
rng = Xoshiro(0),
2928
init_params = Lux.initialparameters(rng, chain),
@@ -46,6 +45,12 @@ function NeuralNetworkBlock(n_input = 1,
4645
return ude_comp
4746
end
4847

48+
# added to avoid a breaking change from moving n_input & n_output in kwargs
49+
# https://github.com/SciML/ModelingToolkitNeuralNets.jl/issues/32
50+
function NeuralNetworkBlock(n_input, n_output = 1; kwargs...)
51+
NeuralNetworkBlock(; n_input, n_output, kwargs...)
52+
end
53+
4954
function lazyconvert(T, x::Symbolics.Arr)
5055
Symbolics.array_term(convert, T, x, size = size(x))
5156
end

0 commit comments

Comments
 (0)