Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/ModelingToolkitNeuralNets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ function (wrapper::StatelessApplyWrapper)(input::AbstractArray, nn_p::AbstractVe
stateless_apply(get_network(wrapper), input, convert(wrapper.T, nn_p))
end

function (wrapper::StatelessApplyWrapper)(input::Number, nn_p::AbstractVector)
wrapper([input], nn_p)
end

function Base.show(io::IO, m::MIME"text/plain", wrapper::StatelessApplyWrapper)
printstyled(io, "LuxCore.stateless_apply wrapper for:\n", color = :gray)
show(io, m, get_network(wrapper))
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ using SafeTestsets
@safetestset "QA" include("qa.jl")
@safetestset "Basic" include("lotka_volterra.jl")
@safetestset "MTK model macro compatibility" include("macro.jl")
@safetestset "Scalar dispatch" include("scalar_dispatch.jl")
end
71 changes: 71 additions & 0 deletions test/scalar_dispatch.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
using Test
using ModelingToolkitNeuralNets
using ModelingToolkit
using Lux
using StableRNGs
using OrdinaryDiffEqVerner

# Test scalar dispatch for SymbolicNeuralNetwork
# This tests the fix for issue #83
@testset "Scalar dispatch" begin
# Create a simple UDE with scalar inputs
@variables t X(t) Y(t)
@parameters d

chain = Lux.Chain(
Lux.Dense(1 => 3, Lux.softplus, use_bias = false),
Lux.Dense(3 => 3, Lux.softplus, use_bias = false),
Lux.Dense(3 => 1, Lux.softplus, use_bias = false)
)

sym_nn,
θ = SymbolicNeuralNetwork(;
nn_p_name = :θ, chain, n_input = 1, n_output = 1, rng = StableRNG(42))

# Test that scalar dispatch works (fix for issue #83)
# Previously required: sym_nn([Y], θ)[1]
# Now can use: sym_nn(Y, θ)[1]
Dt = ModelingToolkit.D_nounits
eqs_ude = [
Dt(X) ~ sym_nn(Y, θ)[1] - d*X,
Dt(Y) ~ X - d*Y
]

@named sys = System(eqs_ude, ModelingToolkit.t_nounits)
sys_compiled = mtkcompile(sys)

# Test that the system can be created and solved
prob = ODEProblem{true, SciMLBase.FullSpecialize}(
sys_compiled,
[X => 1.0, Y => 1.0],
(0.0, 1.0),
[d => 0.1]
)

sol = solve(prob, Vern9(), abstol = 1e-8, reltol = 1e-8)

@test SciMLBase.successful_retcode(sol)

# Also test that the old array syntax still works
eqs_ude_old = [
Dt(X) ~ sym_nn([Y], θ)[1] - d*X,
Dt(Y) ~ X - d*Y
]

@named sys_old = System(eqs_ude_old, ModelingToolkit.t_nounits)
sys_old_compiled = mtkcompile(sys_old)

prob_old = ODEProblem{true, SciMLBase.FullSpecialize}(
sys_old_compiled,
[X => 1.0, Y => 1.0],
(0.0, 1.0),
[d => 0.1]
)

sol_old = solve(prob_old, Vern9(), abstol = 1e-8, reltol = 1e-8)

@test SciMLBase.successful_retcode(sol_old)

# Both solutions should be the same
@test sol.u ≈ sol_old.u
end
Loading