Skip to content

Commit 063379f

Browse files
fix: fix type promotion in late_binding_update_u0_p with non-dual types
1 parent 2035e73 commit 063379f

File tree

2 files changed

+41
-13
lines changed

2 files changed

+41
-13
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -638,24 +638,42 @@ function SciMLBase.remake_initialization_data(
638638
return SciMLBase.remake_initialization_data(sys, odefn, newu0, t0, newp, newu0, newp)
639639
end
640640

641-
function promote_u0_p(u0, p::MTKParameters, t0)
642-
u0 = DiffEqBase.promote_u0(u0, p.tunable, t0)
643-
u0 = DiffEqBase.promote_u0(u0, p.initials, t0)
641+
promote_type_with_nothing(::Type{T}, ::Nothing) where {T} = T
642+
promote_type_with_nothing(::Type{T}, ::SizedVector{0}) where {T} = T
643+
function promote_type_with_nothing(::Type{T}, ::AbstractArray{T2}) where {T, T2}
644+
promote_type(T, T2)
645+
end
644646

645-
if !isempty(p.tunable)
646-
tunables = DiffEqBase.promote_u0(p.tunable, u0, t0)
647-
p = SciMLStructures.replace(SciMLStructures.Tunable(), p, tunables)
648-
end
649-
if !isempty(p.initials)
650-
initials = DiffEqBase.promote_u0(p.initials, u0, t0)
651-
p = SciMLStructures.replace(SciMLStructures.Initials(), p, initials)
647+
promote_with_nothing(::Type, ::Nothing) = nothing
648+
promote_with_nothing(::Type, x::SizedVector{0}) = x
649+
promote_with_nothing(::Type{T}, x::AbstractArray{T}) where {T} = x
650+
function promote_with_nothing(::Type{T}, x::AbstractArray{T2}) where {T, T2}
651+
if ArrayInterface.ismutable(x)
652+
y = similar(x, T)
653+
copyto!(y, x)
654+
return y
655+
else
656+
yT = similar_type(x, T)
657+
return yT(x)
652658
end
653-
654-
return u0, p
659+
end
660+
function promote_with_nothing(::Type{T}, p::MTKParameters) where {T}
661+
tunables = promote_with_nothing(T, p.tunable)
662+
p = SciMLStructures.replace(SciMLStructures.Tunable(), p, tunables)
663+
initials = promote_with_nothing(T, p.initials)
664+
p = SciMLStructures.replace(SciMLStructures.Initials(), p, initials)
665+
return p
655666
end
656667

657668
function promote_u0_p(u0, p, t0)
658-
return DiffEqBase.promote_u0(u0, p, t0), DiffEqBase.promote_u0(p, u0, t0)
669+
T = Union{}
670+
T = promote_type_with_nothing(T, u0)
671+
T = promote_type_with_nothing(T, p.tunable)
672+
T = promote_type_with_nothing(T, p.initials)
673+
674+
u0 = promote_with_nothing(T, u0)
675+
p = promote_with_nothing(T, p)
676+
return u0, p
659677
end
660678

661679
function SciMLBase.late_binding_update_u0_p(

test/initial_values.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,3 +345,13 @@ end
345345
@test state_values(initdata.initializeprob) isa SVector
346346
@test parameter_values(initdata.initializeprob) isa SVector
347347
end
348+
349+
@testset "Type promotion of `p` works with non-dual types" begin
350+
@variables x(t) y(t)
351+
@mtkbuild sys = ODESystem([D(x) ~ x + y, x^3 + y^3 ~ 5], t; guesses = [y => 1.0])
352+
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0))
353+
prob2 = remake(prob; u0 = BigFloat.(prob.u0))
354+
@test prob2.p.initials isa Vector{BigFloat}
355+
sol = solve(prob2)
356+
@test SciMLBase.successful_retcode(sol)
357+
end

0 commit comments

Comments
 (0)