Skip to content

Commit fa14fdd

Browse files
Merge pull request #3156 from AayushSabharwal/as/varmap-partial-eval
fix: only evaluate required keys of varmap in `process_SciMLProblem`
2 parents a0fe7c3 + 7aba2fc commit fa14fdd

File tree

2 files changed

+23
-7
lines changed

2 files changed

+23
-7
lines changed

src/systems/problem_utils.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -301,12 +301,14 @@ end
301301
"""
302302
$(TYPEDSIGNATURES)
303303
304-
Performs symbolic substitution on the values in `varmap`, using `varmap` itself as the
305-
set of substitution rules.
306-
"""
307-
function evaluate_varmap!(varmap::AbstractDict)
308-
for (k, v) in varmap
309-
varmap[k] = fixpoint_sub(v, varmap)
304+
Performs symbolic substitution on the values in `varmap` for the keys in `vars`, using
305+
`varmap` itself as the set of substitution rules. If an entry in `vars` is not a key
306+
in `varmap`, it is ignored.
307+
"""
308+
function evaluate_varmap!(varmap::AbstractDict, vars)
309+
for k in vars
310+
haskey(varmap, k) || continue
311+
varmap[k] = fixpoint_sub(varmap[k], varmap)
310312
end
311313
end
312314

@@ -499,7 +501,7 @@ function process_SciMLProblem(
499501
add_observed!(sys, op)
500502
add_parameter_dependencies!(sys, op)
501503

502-
evaluate_varmap!(op)
504+
evaluate_varmap!(op, dvs)
503505

504506
u0 = better_varmap_to_vars(
505507
op, dvs; tofloat = true, use_union = false,
@@ -511,6 +513,7 @@ function process_SciMLProblem(
511513

512514
check_eqs_u0(eqs, dvs, u0; check_length, kwargs...)
513515

516+
evaluate_varmap!(op, ps)
514517
if is_split(sys)
515518
p = MTKParameters(sys, op)
516519
else

test/initial_values.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,3 +136,16 @@ end
136136
prob = @test_nowarn ODEProblem(sys, [], (0.0, 1.0))
137137
@test prob.p isa Vector{Float64}
138138
end
139+
140+
@testset "Issue#3153" begin
141+
@variables x(t) y(t)
142+
@parameters c1 c2 c3
143+
eqs = [D(x) ~ y,
144+
y ~ ifelse(t < c1, 0.0, (-c1 + t)^(c3))]
145+
sps = [x, y]
146+
ps = [c1, c2, c3]
147+
@mtkbuild osys = ODESystem(eqs, t, sps, ps)
148+
u0map = [x => 1.0]
149+
pmap = [c1 => 5.0, c2 => 1.0, c3 => 1.2]
150+
oprob = ODEProblem(osys, u0map, (0.0, 10.0), pmap)
151+
end

0 commit comments

Comments
 (0)