Skip to content

Commit 84c173b

Browse files
committed
Consolidate reported issues
1 parent 69ffd14 commit 84c173b

File tree

2 files changed

+36
-20
lines changed

2 files changed

+36
-20
lines changed
Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,7 @@
1-
using Test
2-
using ModelingToolkitNeuralNets
3-
using ModelingToolkit
4-
using Lux
5-
using StableRNGs
1+
using Test, Lux, ModelingToolkitNeuralNets, StableRNGs, ModelingToolkit
62
using OrdinaryDiffEqVerner
73

8-
# Test scalar dispatch for SymbolicNeuralNetwork
9-
# This tests the fix for issue #83
10-
@testset "Scalar dispatch" begin
4+
@testset "Scalar dispatch (issue #83)" begin
115
# Create a simple UDE with scalar inputs
126
@variables t X(t) Y(t)
137
@parameters d
@@ -27,8 +21,8 @@ using OrdinaryDiffEqVerner
2721
# Now can use: sym_nn(Y, θ)[1]
2822
Dt = ModelingToolkit.D_nounits
2923
eqs_ude = [
30-
Dt(X) ~ sym_nn(Y, θ)[1] - d*X,
31-
Dt(Y) ~ X - d*Y
24+
Dt(X) ~ sym_nn(Y, θ)[1] - d * X,
25+
Dt(Y) ~ X - d * Y
3226
]
3327

3428
@named sys = System(eqs_ude, ModelingToolkit.t_nounits)
@@ -37,9 +31,8 @@ using OrdinaryDiffEqVerner
3731
# Test that the system can be created and solved
3832
prob = ODEProblem{true, SciMLBase.FullSpecialize}(
3933
sys_compiled,
40-
[X => 1.0, Y => 1.0],
41-
(0.0, 1.0),
42-
[d => 0.1]
34+
[X => 1.0, Y => 1.0, d => 0.1],
35+
(0.0, 1.0)
4336
)
4437

4538
sol = solve(prob, Vern9(), abstol = 1e-8, reltol = 1e-8)
@@ -48,24 +41,47 @@ using OrdinaryDiffEqVerner
4841

4942
# Also test that the old array syntax still works
5043
eqs_ude_old = [
51-
Dt(X) ~ sym_nn([Y], θ)[1] - d*X,
52-
Dt(Y) ~ X - d*Y
44+
Dt(X) ~ sym_nn([Y], θ)[1] - d * X,
45+
Dt(Y) ~ X - d * Y
5346
]
5447

5548
@named sys_old = System(eqs_ude_old, ModelingToolkit.t_nounits)
5649
sys_old_compiled = mtkcompile(sys_old)
5750

5851
prob_old = ODEProblem{true, SciMLBase.FullSpecialize}(
5952
sys_old_compiled,
60-
[X => 1.0, Y => 1.0],
61-
(0.0, 1.0),
62-
[d => 0.1]
53+
[X => 1.0, Y => 1.0, d => 0.1],
54+
(0.0, 1.0)
6355
)
6456

6557
sol_old = solve(prob_old, Vern9(), abstol = 1e-8, reltol = 1e-8)
6658

6759
@test SciMLBase.successful_retcode(sol_old)
6860

6961
# Both solutions should be the same
70-
@test sol.u sol_old.u
62+
@test sol.u == sol_old.u
63+
end
64+
65+
@testset "Issue #58" begin
66+
# Preparation
67+
rng = StableRNG(123)
68+
chain = Lux.Chain(
69+
Lux.Dense(1 => 3, Lux.softplus, use_bias = false),
70+
Lux.Dense(3 => 3, Lux.softplus, use_bias = false),
71+
Lux.Dense(3 => 1, Lux.sigmoid_fast, use_bias = false)
72+
)
73+
74+
# Default names.
75+
NN, NN_p = SymbolicNeuralNetwork(; chain, n_input = 1, n_output = 1, rng)
76+
@test ModelingToolkit.getname(NN) == :nn_name
77+
@test ModelingToolkit.getname(NN_p) == :p
78+
79+
# Trying to set specific names.
80+
nn_name = :custom_nn_name
81+
nn_p_name = :custom_nn_p_name
82+
NN, NN_p = SymbolicNeuralNetwork(;
83+
chain, n_input = 1, n_output = 1, rng, nn_name, nn_p_name)
84+
85+
@test ModelingToolkit.getname(NN)==nn_name broken=true # :nn_name # Should be :custom_nn_name
86+
@test ModelingToolkit.getname(NN_p) == nn_p_name
7187
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,5 @@ using SafeTestsets
66
@safetestset "QA" include("qa.jl")
77
@safetestset "Basic" include("lotka_volterra.jl")
88
@safetestset "MTK model macro compatibility" include("macro.jl")
9-
@safetestset "Scalar dispatch" include("scalar_dispatch.jl")
9+
@safetestset "Reported issues" include("reported_issues.jl")
1010
end

0 commit comments

Comments
 (0)