diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index d0b687c212..ae4feec62b 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -388,8 +388,8 @@ function IndexCache(sys::AbstractSystem) observed_syms_to_timeseries, dependent_pars_to_timeseries, disc_buffer_templates, - BufferTemplate(Real, tunable_buffer_size), - BufferTemplate(Real, initials_buffer_size), + BufferTemplate(Number, tunable_buffer_size), + BufferTemplate(Number, initials_buffer_size), const_buffer_sizes, nonnumeric_buffer_sizes, symbol_to_variable diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index b893d9ffc6..ec72e41fe7 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -962,7 +962,7 @@ end $(TYPEDEF) A callable struct to use as the `get_updated_u0` field of `InitializationMetadata`. -Returns the value to use for the `u0` of the problem. +Returns the value to use for the `u0` of the problem. # Fields @@ -1160,7 +1160,7 @@ function float_type_from_varmap(varmap, floatT = Bool) if v isa AbstractArray floatT = promote_type(floatT, eltype(v)) - elseif v isa Real + elseif v isa Number floatT = promote_type(floatT, typeof(v)) end end @@ -1432,7 +1432,7 @@ function check_inputmap_keys(sys, u0map, pmap) end const BAD_KEY_MESSAGE = """ - Undefined keys found in the parameter or initial condition maps. Check if symbolic variable names have been reassigned. + Undefined keys found in the parameter or initial condition maps. Check if symbolic variable names have been reassigned. The following keys are invalid: """ diff --git a/test/complex.jl b/test/complex.jl index 69cc22c985..04be8e4dac 100644 --- a/test/complex.jl +++ b/test/complex.jl @@ -1,4 +1,5 @@ using ModelingToolkit +using OrdinaryDiffEq using ModelingToolkit: t_nounits as t using Test @@ -14,3 +15,30 @@ using Test end @named mixed = ComplexModel() @test length(equations(mixed)) == 2 + +@testset "Complex ODEProblem" begin + using ModelingToolkit: t_nounits as t, D_nounits as D + + vars = @variables x(t) y(t) z(t) + pars = @parameters a b + + eqs = [ + D(x) ~ y - x, + D(y) ~ -x * z + b * abs(z), + D(z) ~ x * y - a + ] + @named modlorenz = System(eqs, t) + sys = structural_simplify(modlorenz) + + ic = ModelingToolkit.get_index_cache(sys) + @test ic.tunable_buffer_size.type == Number + + u0 = ComplexF64[-4.0, 5.0, 0.0] .+ randn(ComplexF64, 3) + p = ComplexF64[5.0, 0.1] + dict = merge(Dict(unknowns(sys) .=> u0), Dict(parameters(sys) .=> p)) + prob = ODEProblem(sys, dict, (0.0, 1.0)) + + sol = solve(prob, Tsit5(), saveat = 0.1) + + @test sol.u[1] isa Vector{ComplexF64} +end