Skip to content

Commit 6bdf82a

Browse files
Merge pull request #84 from ChrisRackauckas-Claude/add-scalar-dispatch-fix-83
Add scalar dispatch for SymbolicNeuralNetwork (fixes #83)
2 parents ed430cc + a5b8ebc commit 6bdf82a

File tree

3 files changed

+76
-0
lines changed

3 files changed

+76
-0
lines changed

src/ModelingToolkitNeuralNets.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,10 @@ function (wrapper::StatelessApplyWrapper)(input::AbstractArray, nn_p::AbstractVe
121121
stateless_apply(get_network(wrapper), input, convert(wrapper.T, nn_p))
122122
end
123123

124+
function (wrapper::StatelessApplyWrapper)(input::Number, nn_p::AbstractVector)
125+
wrapper([input], nn_p)
126+
end
127+
124128
function Base.show(io::IO, m::MIME"text/plain", wrapper::StatelessApplyWrapper)
125129
printstyled(io, "LuxCore.stateless_apply wrapper for:\n", color = :gray)
126130
show(io, m, get_network(wrapper))

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@ using SafeTestsets
66
@safetestset "QA" include("qa.jl")
77
@safetestset "Basic" include("lotka_volterra.jl")
88
@safetestset "MTK model macro compatibility" include("macro.jl")
9+
@safetestset "Scalar dispatch" include("scalar_dispatch.jl")
910
end

test/scalar_dispatch.jl

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
using Test
2+
using ModelingToolkitNeuralNets
3+
using ModelingToolkit
4+
using Lux
5+
using StableRNGs
6+
using OrdinaryDiffEqVerner
7+
8+
# Test scalar dispatch for SymbolicNeuralNetwork
9+
# This tests the fix for issue #83
10+
@testset "Scalar dispatch" begin
11+
# Create a simple UDE with scalar inputs
12+
@variables t X(t) Y(t)
13+
@parameters d
14+
15+
chain = Lux.Chain(
16+
Lux.Dense(1 => 3, Lux.softplus, use_bias = false),
17+
Lux.Dense(3 => 3, Lux.softplus, use_bias = false),
18+
Lux.Dense(3 => 1, Lux.softplus, use_bias = false)
19+
)
20+
21+
sym_nn,
22+
θ = SymbolicNeuralNetwork(;
23+
nn_p_name = , chain, n_input = 1, n_output = 1, rng = StableRNG(42))
24+
25+
# Test that scalar dispatch works (fix for issue #83)
26+
# Previously required: sym_nn([Y], θ)[1]
27+
# Now can use: sym_nn(Y, θ)[1]
28+
Dt = ModelingToolkit.D_nounits
29+
eqs_ude = [
30+
Dt(X) ~ sym_nn(Y, θ)[1] - d*X,
31+
Dt(Y) ~ X - d*Y
32+
]
33+
34+
@named sys = System(eqs_ude, ModelingToolkit.t_nounits)
35+
sys_compiled = mtkcompile(sys)
36+
37+
# Test that the system can be created and solved
38+
prob = ODEProblem{true, SciMLBase.FullSpecialize}(
39+
sys_compiled,
40+
[X => 1.0, Y => 1.0],
41+
(0.0, 1.0),
42+
[d => 0.1]
43+
)
44+
45+
sol = solve(prob, Vern9(), abstol = 1e-8, reltol = 1e-8)
46+
47+
@test SciMLBase.successful_retcode(sol)
48+
49+
# Also test that the old array syntax still works
50+
eqs_ude_old = [
51+
Dt(X) ~ sym_nn([Y], θ)[1] - d*X,
52+
Dt(Y) ~ X - d*Y
53+
]
54+
55+
@named sys_old = System(eqs_ude_old, ModelingToolkit.t_nounits)
56+
sys_old_compiled = mtkcompile(sys_old)
57+
58+
prob_old = ODEProblem{true, SciMLBase.FullSpecialize}(
59+
sys_old_compiled,
60+
[X => 1.0, Y => 1.0],
61+
(0.0, 1.0),
62+
[d => 0.1]
63+
)
64+
65+
sol_old = solve(prob_old, Vern9(), abstol = 1e-8, reltol = 1e-8)
66+
67+
@test SciMLBase.successful_retcode(sol_old)
68+
69+
# Both solutions should be the same
70+
@test sol.u sol_old.u
71+
end

0 commit comments

Comments
 (0)