Skip to content

Cannot simulate single input SymbolicNeuralNetwork model #57

@TorkelE

Description

@TorkelE

@SebastianM-C

Unfortunatley I cannot get the new symbolic interface to work. This example, adapted from the tests, work:

# Fetch packages.
using ModelingToolkit, ModelingToolkitNeuralNets
using ModelingToolkit: t_nounits as t, D_nounits as D
using OrdinaryDiffEq
using StableRNGs

let
    # Make Neural network.
    chain = multi_layer_feed_forward(2, 2)
    NN, p = SymbolicNeuralNetwork(; chain, n_input = 2, n_output = 2, rng = StableRNG(42))

    # Create UDE model.
    @variables x(t) y(t)
    @parameters α δ
    eqs = [
        D(x) ~ α * x + NN([x, y], p)[1],
        D(y) ~ -δ * y + NN([x, y], p)[2]]
    @mtkbuild osys = ODESystem(eqs, ModelingToolkit.t_nounits)

    # Test UDE simulation.
    prob = ODEProblem(osys, [x => 3.1, y => 1.5], (0, 1.0), [α => 1.3 δ => 1.8])
    sol = solve(prob)
end

However, if I try to create a model with only a single input to the neural network (here changing n_input to 1, and making similar modification in the call) I get an error:

let
    # Make Neural network.
    chain = multi_layer_feed_forward(2, 2)
    NN, p = SymbolicNeuralNetwork(; chain, n_input = 1, n_output = 2, rng = StableRNG(42))

    # Create UDE model.
    @variables x(t) y(t)
    @parameters α δ
    eqs = [
        D(x) ~ α * x + NN([x], p)[1],
        D(y) ~ -δ * y + NN([x], p)[2]]
    @mtkbuild osys = ODESystem(eqs, ModelingToolkit.t_nounits)

    # Test UDE simulation.
    prob = ODEProblem(osys, [x => 3.1, y => 1.5], (0, 1.0), [α => 1.3 δ => 1.8])
    sol = solve(prob)
end
ERROR: DimensionMismatch: A has dimensions (5,2) but B has dimensions (1,1)
Stacktrace:
  [1] gemm_wrapper!(C::Matrix{…}, tA::Char, tB::Char, A::Base.ReshapedArray{…}, B::Matrix{…}, _add::LinearAlgebra.MulAddMul{…})
    @ LinearAlgebra ~/.julia/juliaup/julia-1.11.5+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:629
  [2] generic_matmatmul!
    @ ~/.julia/juliaup/julia-1.11.5+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:381 [inlined]
  [3] _mul!
    @ ~/.julia/juliaup/julia-1.11.5+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:287 [inlined]
  [4] mul!
    @ ~/.julia/juliaup/julia-1.11.5+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:285 [inlined]
  [5] matmul_linalg_default!
    @ ~/.julia/packages/LuxLib/EnsP3/src/impl/matmul.jl:205 [inlined]
  [6] matmul_cpu_fallback!
    @ ~/.julia/packages/LuxLib/EnsP3/src/impl/matmul.jl:177 [inlined]
  [7] matmul_cpu!
    @ ~/.julia/packages/LuxLib/EnsP3/src/impl/matmul.jl:164 [inlined]
  [8] matmul!
    @ ~/.julia/packages/LuxLib/EnsP3/src/impl/matmul.jl:124 [inlined]
  [9] fused_dense!
    @ ~/.julia/packages/LuxLib/EnsP3/src/impl/dense.jl:54 [inlined]
 [10] fused_dense
    @ ~/.julia/packages/LuxLib/EnsP3/src/impl/dense.jl:42 [inlined]
 [11] fused_dense
    @ ~/.julia/packages/LuxLib/EnsP3/src/impl/dense.jl:16 [inlined]
 [12] fused_dense_bias_activation
    @ ~/.julia/packages/LuxLib/EnsP3/src/api/dense.jl:36 [inlined]
 [13] (::Lux.Dense{…})(x::Vector{…}, ps::ComponentArrays.ComponentVector{…}, st::@NamedTuple{})
    @ Lux ~/.julia/packages/Lux/CYnn3/src/layers/basic.jl:357
 [14] apply
    @ ~/.julia/packages/LuxCore/Av7WJ/src/LuxCore.jl:155 [inlined]
 [15] macro expansion
    @ ~/.julia/packages/Lux/CYnn3/src/layers/containers.jl:0 [inlined]
 [16] applychain
    @ ~/.julia/packages/Lux/CYnn3/src/layers/containers.jl:511 [inlined]
 [17] Chain
    @ ~/.julia/packages/Lux/CYnn3/src/layers/containers.jl:509 [inlined]
 [18] apply
    @ ~/.julia/packages/LuxCore/Av7WJ/src/LuxCore.jl:155 [inlined]
 [19] stateless_apply(model::Lux.Chain{…}, x::Vector{…}, ps::ComponentArrays.ComponentVector{…})
    @ LuxCore ~/.julia/packages/LuxCore/Av7WJ/src/LuxCore.jl:166
 [20] (::ModelingToolkitNeuralNets.StatelessApplyWrapper{…})(input::Vector{…}, nn_p::SubArray{…})
    @ ModelingToolkitNeuralNets ~/.julia/packages/ModelingToolkitNeuralNets/qmb5j/src/ModelingToolkitNeuralNets.jl:120
 [21] macro expansion
    @ ~/.julia/packages/SymbolicUtils/NGWJM/src/code.jl:510 [inlined]
 [22] macro expansion
    @ ~/.julia/packages/Symbolics/B6Z8m/src/build_function.jl:368 [inlined]
 [23] macro expansion
    @ ~/.julia/packages/RuntimeGeneratedFunctions/M9ZX8/src/RuntimeGeneratedFunctions.jl:163 [inlined]
 [24] macro expansion
    @ ./none:0 [inlined]
 [25] generated_callfunc
    @ ./none:0 [inlined]
 [26] (::RuntimeGeneratedFunctions.RuntimeGeneratedFunction{…})(::Vector{…}, ::Vector{…}, ::MTKParameters{…}, ::Float64)
    @ RuntimeGeneratedFunctions ~/.julia/packages/RuntimeGeneratedFunctions/M9ZX8/src/RuntimeGeneratedFunctions.jl:150
 [27] macro expansion
    @ ~/.julia/packages/ModelingToolkit/aau6A/src/systems/codegen_utils.jl:0 [inlined]
 [28] _generated_call
    @ ~/.julia/packages/ModelingToolkit/aau6A/src/systems/codegen_utils.jl:262 [inlined]
 [29] GeneratedFunctionWrapper
    @ ~/.julia/packages/ModelingToolkit/aau6A/src/systems/codegen_utils.jl:259 [inlined]
 [30] Void
    @ ~/.julia/packages/SciMLBase/c6Noy/src/utils.jl:486 [inlined]
 [31] (::FunctionWrappers.CallWrapper{…})(f::SciMLBase.Void{…}, arg1::Vector{…}, arg2::Vector{…}, arg3::MTKParameters{…}, arg4::Float64)
    @ FunctionWrappers ~/.julia/packages/FunctionWrappers/Q5cBx/src/FunctionWrappers.jl:65
 [32] macro expansion
    @ ~/.julia/packages/FunctionWrappers/Q5cBx/src/FunctionWrappers.jl:137 [inlined]
 [33] do_ccall
    @ ~/.julia/packages/FunctionWrappers/Q5cBx/src/FunctionWrappers.jl:125 [inlined]
 [34] FunctionWrapper
    @ ~/.julia/packages/FunctionWrappers/Q5cBx/src/FunctionWrappers.jl:144 [inlined]
 [35] _call
    @ ~/.julia/packages/FunctionWrappersWrappers/9XR0m/src/FunctionWrappersWrappers.jl:12 [inlined]
 [36] FunctionWrappersWrapper
    @ ~/.julia/packages/FunctionWrappersWrappers/9XR0m/src/FunctionWrappersWrappers.jl:10 [inlined]
--- the above 7 lines are repeated 1 more time ---
 [44] ODEFunction
    @ ~/.julia/packages/SciMLBase/c6Noy/src/scimlfunctions.jl:2470 [inlined]
 [45] initialize!(integrator::OrdinaryDiffEqCore.ODEIntegrator{…}, cache::OrdinaryDiffEqTsit5.Tsit5Cache{…})
    @ OrdinaryDiffEqTsit5 ~/.julia/packages/OrdinaryDiffEqTsit5/DHYtz/src/tsit_perform_step.jl:175
 [46] initialize!(integrator::OrdinaryDiffEqCore.ODEIntegrator{…}, cache::OrdinaryDiffEqCore.DefaultCache{…})
    @ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/bMOsj/src/perform_step/composite_perform_step.jl:38
 [47] __init(prob::ODEProblem{…}, alg::CompositeAlgorithm{…}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{…}; saveat::Tuple{}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::Nothing, dense::Bool, calck::Bool, dt::Float64, dtmin::Float64, dtmax::Float64, force_dtmin::Bool, adaptive::Bool, gamma::Rational{…}, abstol::Nothing, reltol::Nothing, qmin::Rational{…}, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Rational{…}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias::ODEAliasSpecifier, initializealg::OrdinaryDiffEqCore.DefaultInit, kwargs::@Kwargs{…})
    @ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/bMOsj/src/solve.jl:587
 [48] __init (repeats 5 times)
    @ ~/.julia/packages/OrdinaryDiffEqCore/bMOsj/src/solve.jl:11 [inlined]
 [49] #__solve#62
    @ ~/.julia/packages/OrdinaryDiffEqCore/bMOsj/src/solve.jl:6 [inlined]
 [50] __solve
    @ ~/.julia/packages/OrdinaryDiffEqCore/bMOsj/src/solve.jl:1 [inlined]
 [51] solve_call(_prob::ODEProblem{…}, args::CompositeAlgorithm{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/zYZst/src/solve.jl:635
 [52] solve_up(prob::ODEProblem{…}, sensealg::Nothing, u0::Vector{…}, p::MTKParameters{…}, args::CompositeAlgorithm{…}; kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/zYZst/src/solve.jl:1142
 [53] solve_up
    @ ~/.julia/packages/DiffEqBase/zYZst/src/solve.jl:1120 [inlined]
 [54] #solve#42
    @ ~/.julia/packages/DiffEqBase/zYZst/src/solve.jl:1059 [inlined]
 [55] solve
    @ ~/.julia/packages/DiffEqBase/zYZst/src/solve.jl:1047 [inlined]
 [56] #__solve#3
    @ ~/.julia/packages/OrdinaryDiffEqDefault/kOV75/src/default_alg.jl:48 [inlined]
 [57] __solve
    @ ~/.julia/packages/OrdinaryDiffEqDefault/kOV75/src/default_alg.jl:47 [inlined]
 [58] #__solve#63
    @ ~/.julia/packages/DiffEqBase/zYZst/src/solve.jl:1451 [inlined]
 [59] __solve
    @ ~/.julia/packages/DiffEqBase/zYZst/src/solve.jl:1442 [inlined]
 [60] #solve_call#35
    @ ~/.julia/packages/DiffEqBase/zYZst/src/solve.jl:635 [inlined]
 [61] solve_call(::ODEProblem{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/zYZst/src/solve.jl:592
 [62] solve_up(::ODEProblem{…}, ::Nothing, ::Vector{…}, ::MTKParameters{…}; kwargs::@Kwargs{})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/zYZst/src/solve.jl:1126
 [63] solve_up
    @ ~/.julia/packages/DiffEqBase/zYZst/src/solve.jl:1120 [inlined]
 [64] solve(::ODEProblem{…}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{…}, kwargs::@Kwargs{})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/zYZst/src/solve.jl:1057
 [65] solve(::ODEProblem{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/zYZst/src/solve.jl:1047
Some type information was truncated. Use `show(err)` to see complete types.

(I originally I had a different model where I tried to get SymbolicNeuralNetwork to work, but that failed similarly, but started working of I increased n_input to 2).

Metadata

Metadata

Assignees

No one assigned

    Labels

    bugSomething isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions