Skip to content

Commit a4d64e6

Browse files
feat: add missing_is_symbolic to remake
1 parent adb8af8 commit a4d64e6

File tree

2 files changed

+62
-46
lines changed

2 files changed

+62
-46
lines changed

src/remake.jl

Lines changed: 45 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -75,17 +75,22 @@ be interpreted as a symbolic map and used as-is for parameters. `use_defaults` a
7575
controlling whether the default values from the system will be used to calculate missing
7676
values in the symbolic map passed to `u0` or `p`. It is only valid when either `u0` or
7777
`p` have been explicitly provided as a symbolic map and the problem has an associated
78-
system.
78+
system. If `missing_is_symbolic`, not providing `u0`/`p` is interpreted as an empty symbolic
79+
map. Otherwise, they retain the existing value in `prob`. If neither `u0` nor `p` is provided,
80+
then they retain the existing values in `prob` regardless of `missing_is_symbolic`.
7981
"""
8082
function remake(prob::AbstractSciMLProblem; u0 = missing,
81-
p = missing, interpret_symbolicmap = true, use_defaults = false, kwargs...)
82-
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults)
83+
p = missing, interpret_symbolicmap = true, use_defaults = false,
84+
missing_is_symbolic = true, kwargs...)
85+
u0, p = updated_u0_p(
86+
prob, u0, p; interpret_symbolicmap, use_defaults, missing_is_symbolic)
8387
_remake_internal(prob; kwargs..., u0, p)
8488
end
8589

8690
function remake(prob::AbstractIntervalNonlinearProblem; p = missing,
87-
interpret_symbolicmap = true, use_defaults = false, kwargs...)
88-
_, p = updated_u0_p(prob, [], p; interpret_symbolicmap, use_defaults)
91+
interpret_symbolicmap = true, use_defaults = false, missing_is_symbolic = true, kwargs...)
92+
_, p = updated_u0_p(
93+
prob, [], p; interpret_symbolicmap, use_defaults, missing_is_symbolic)
8994
_remake_internal(prob; kwargs..., p)
9095
end
9196

@@ -94,8 +99,10 @@ function remake(prob::AbstractNoiseProblem; kwargs...)
9499
end
95100

96101
function remake(
97-
prob::AbstractIntegralProblem; p = missing, interpret_symbolicmap = true, use_defaults = false, kwargs...)
98-
_, p = updated_u0_p(prob, nothing, p; interpret_symbolicmap, use_defaults)
102+
prob::AbstractIntegralProblem; p = missing, interpret_symbolicmap = true,
103+
use_defaults = false, missing_is_symbolic = true, kwargs...)
104+
_, p = updated_u0_p(
105+
prob, nothing, p; interpret_symbolicmap, use_defaults, missing_is_symbolic)
99106
_remake_internal(prob; kwargs..., p)
100107
end
101108

@@ -114,12 +121,14 @@ function remake(prob::ODEProblem; f = missing,
114121
interpret_symbolicmap = true,
115122
build_initializeprob = true,
116123
use_defaults = false,
124+
missing_is_symbolic = true,
117125
_kwargs...)
118126
if tspan === missing
119127
tspan = prob.tspan
120128
end
121129

122-
newu0, newp = updated_u0_p(prob, u0, p, tspan[1]; interpret_symbolicmap, use_defaults)
130+
newu0, newp = updated_u0_p(
131+
prob, u0, p, tspan[1]; interpret_symbolicmap, use_defaults, missing_is_symbolic)
123132

124133
iip = isinplace(prob)
125134

@@ -231,7 +240,8 @@ Remake the given `BVProblem`.
231240
"""
232241
function remake(prob::BVProblem{uType, tType, iip, nlls}; f = missing, bc = missing,
233242
u0 = missing, tspan = missing, p = missing, kwargs = missing, problem_type = missing,
234-
interpret_symbolicmap = true, use_defaults = false, _kwargs...) where {
243+
interpret_symbolicmap = true, use_defaults = false, missing_is_symbolic = true,
244+
_kwargs...) where {
235245
uType, tType, iip, nlls}
236246
if tspan === missing
237247
tspan = prob.tspan
@@ -296,14 +306,16 @@ function remake(prob::SDEProblem;
296306
noise_rate_prototype = missing,
297307
interpret_symbolicmap = true,
298308
use_defaults = false,
309+
missing_is_symbolic = true,
299310
seed = missing,
300311
kwargs = missing,
301312
_kwargs...)
302313
if tspan === missing
303314
tspan = prob.tspan
304315
end
305316

306-
u0, p = updated_u0_p(prob, u0, p, tspan[1]; interpret_symbolicmap, use_defaults)
317+
u0, p = updated_u0_p(
318+
prob, u0, p, tspan[1]; interpret_symbolicmap, use_defaults, missing_is_symbolic)
307319

308320
if noise === missing
309321
noise = prob.noise
@@ -405,8 +417,10 @@ function remake(prob::OptimizationProblem;
405417
kwargs = missing,
406418
interpret_symbolicmap = true,
407419
use_defaults = false,
420+
missing_is_symbolic = true,
408421
_kwargs...)
409-
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults)
422+
u0, p = updated_u0_p(
423+
prob, u0, p; interpret_symbolicmap, use_defaults, missing_is_symbolic)
410424
if f === missing
411425
f = prob.f
412426
end
@@ -447,7 +461,6 @@ end
447461
problem_type = missing, kwargs = missing, _kwargs...)
448462
449463
Remake the given `NonlinearProblem`.
450-
If `u0` or `p` are given as symbolic maps `ModelingToolkit.jl` has to be loaded.
451464
"""
452465
function remake(prob::NonlinearProblem;
453466
f = missing,
@@ -457,8 +470,10 @@ function remake(prob::NonlinearProblem;
457470
kwargs = missing,
458471
interpret_symbolicmap = true,
459472
use_defaults = false,
473+
missing_is_symbolic = true,
460474
_kwargs...)
461-
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults)
475+
u0, p = updated_u0_p(
476+
prob, u0, p; interpret_symbolicmap, use_defaults, missing_is_symbolic)
462477
if f === missing
463478
f = prob.f
464479
end
@@ -483,8 +498,9 @@ end
483498
Remake the given `NonlinearLeastSquaresProblem`.
484499
"""
485500
function remake(prob::NonlinearLeastSquaresProblem; f = missing, u0 = missing, p = missing,
486-
interpret_symbolicmap = true, use_defaults = false, kwargs = missing, _kwargs...)
487-
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults)
501+
interpret_symbolicmap = true, use_defaults = false, missing_is_symbolic = true, kwargs = missing, _kwargs...)
502+
u0, p = updated_u0_p(
503+
prob, u0, p; interpret_symbolicmap, use_defaults, missing_is_symbolic)
488504

489505
if f === missing
490506
f = prob.f
@@ -520,31 +536,22 @@ anydict(d) = Dict{Any, Any}(d)
520536
anydict() = Dict{Any, Any}()
521537

522538
function _updated_u0_p_internal(
523-
prob, ::Missing, ::Missing, t0; interpret_symbolicmap = true, use_defaults = false)
539+
prob, ::Missing, ::Missing, t0; kwargs...)
524540
return state_values(prob), parameter_values(prob)
525541
end
526542
function _updated_u0_p_internal(
527-
prob, ::Missing, p, t0; interpret_symbolicmap = true, use_defaults = false)
528-
u0 = state_values(prob)
529-
530-
if p isa AbstractArray && isempty(p)
531-
return _updated_u0_p_internal(
532-
prob, u0, parameter_values(prob), t0; interpret_symbolicmap)
533-
end
534-
eltype(p) <: Pair && interpret_symbolicmap || return u0, p
535-
defs = default_values(prob)
536-
p = fill_p(prob, anydict(p); defs, use_defaults)
537-
return _updated_u0_p_symmap(prob, u0, Val(false), p, Val(true), t0)
543+
prob, ::Missing, p, t0; interpret_symbolicmap = true,
544+
use_defaults = false, missing_is_symbolic = true)
545+
return _updated_u0_p_internal(
546+
prob, missing_is_symbolic ? anydict() : state_values(prob), p, t0; interpret_symbolicmap)
538547
end
539548

540549
function _updated_u0_p_internal(
541-
prob, u0, ::Missing, t0; interpret_symbolicmap = true, use_defaults = false)
542-
p = parameter_values(prob)
543-
544-
eltype(u0) <: Pair || return u0, p
545-
defs = default_values(prob)
546-
u0 = fill_u0(prob, anydict(u0); defs, use_defaults)
547-
return _updated_u0_p_symmap(prob, u0, Val(true), p, Val(false), t0)
550+
prob, u0, ::Missing, t0; interpret_symbolicmap = true,
551+
use_defaults = false, missing_is_symbolic = true)
552+
return _updated_u0_p_internal(
553+
prob, u0, missing_is_symbolic ? anydict() : parameter_values(prob),
554+
t0; interpret_symbolicmap)
548555
end
549556

550557
function _updated_u0_p_internal(
@@ -725,7 +732,8 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true}, t0)
725732
end
726733

727734
function updated_u0_p(
728-
prob, u0, p, t0 = nothing; interpret_symbolicmap = true, use_defaults = false)
735+
prob, u0, p, t0 = nothing; interpret_symbolicmap = true,
736+
use_defaults = false, missing_is_symbolic = true)
729737
if u0 === missing && p === missing
730738
return state_values(prob), parameter_values(prob)
731739
end
@@ -744,7 +752,8 @@ function updated_u0_p(
744752
return (u0 === missing ? state_values(prob) : u0),
745753
(p === missing ? parameter_values(prob) : p)
746754
end
747-
return _updated_u0_p_internal(prob, u0, p, t0; interpret_symbolicmap, use_defaults)
755+
return _updated_u0_p_internal(
756+
prob, u0, p, t0; interpret_symbolicmap, use_defaults, missing_is_symbolic)
748757
end
749758

750759
# overloaded in MTK to intercept symbolic remake

test/remake_tests.jl

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -198,47 +198,54 @@ for prob in deepcopy(probs)
198198
@test prob2.u0 == u0
199199
@test prob2.p == typeof(prob.p)(p)
200200

201-
# Dependency ignored since `p` was not changed
202-
prob2 = @inferred baseType remake(prob; u0 = [:x => 0.2])
201+
# Dependency ignored since `p` was not changed and `missing_is_symbolic = false`
202+
prob2 = @inferred baseType remake(
203+
prob; u0 = [:x => 0.2], missing_is_symbolic = false)
203204
@test prob2.u0 [0.2, 30.0, 3.0]
204205
@test prob2.p == typeof(prob.p)(p)
205206

206-
# need to pass empty `Dict()` to prevent defaulting to existing values
207+
# with `missing_is_symbolic = true`
208+
prob2 = @inferred baseType remake(
209+
prob; u0 = [:x => 0.2])
210+
@test prob2.u0 [0.2, 30.0, 3.0]
211+
@test all(prob2.p .≈ [10.0, 0.6, 30.0])
212+
213+
# above is identical to this
207214
prob2 = @inferred baseType remake(
208215
prob; u0 = [:x => 0.2], p = Dict())
209216
@test prob2.u0 [0.2, 30.0, 3.0]
210217
@test all(prob2.p .≈ [10.0, 0.6, 30.0])
211218

212219
prob2 = @inferred baseType remake(
213-
prob; u0 = [:x => 0.2], p = Dict(), use_defaults = true)
220+
prob; u0 = [:x => 0.2], use_defaults = true)
214221
@test prob2.u0 [0.2, 3.0, 3.0]
215222
@test all(prob2.p .≈ [1.0, 0.6, 30.0])
216223

217224
# override defaults
218225
prob2 = @inferred baseType remake(
219-
prob; u0 = [:y => 0.2], p = Dict())
226+
prob; u0 = [:y => 0.2])
220227
@test prob2.u0 [1.0, 0.2, 3.0]
221228
@test all(prob2.p .≈ [10.0, 3.0, 30.0])
222229
prob2 = @inferred baseType remake(
223-
prob; u0 = [:y => 0.2], p = Dict(), use_defaults = true)
230+
prob; u0 = [:y => 0.2], use_defaults = true)
224231
@test prob2.u0 [0.1, 0.2, 3.0]
225232
@test all(prob2.p .≈ [1.0, 0.3, 30.0])
226233

227234
prob2 = @inferred baseType remake(
228-
prob; p = [:a => 0.2], u0 = Dict())
235+
prob; p = [:a => 0.2])
229236
@test prob2.u0 [1.0, 0.6, 3.0]
230237
@test all(prob2.p .≈ [0.2, 3.0, 30.0])
231238
prob2 = @inferred baseType remake(
232-
prob; p = [:a => 0.2], u0 = Dict(), use_defaults = true)
239+
prob; p = [:a => 0.2], use_defaults = true)
233240
@test prob2.u0 [0.1, 0.6, 3.0]
234241
@test all(prob2.p .≈ [0.2, 0.3, 30.0])
235242

236243
prob2 = @inferred baseType remake(
237-
prob; p = [:b => 0.2], u0 = Dict())
244+
prob; p = [:b => 0.2])
238245
@test prob2.u0 [1.0, 30.0, 3.0]
239246
@test all(prob2.p .≈ [10.0, 0.2, 30.0])
240247
prob2 = @inferred baseType remake(
241-
prob; p = [:b => 0.2], u0 = Dict(), use_defaults = true)
248+
prob; p = [:b => 0.2], use_defaults = true)
242249
@test prob2.u0 [0.1, 3.0, 3.0]
243250
@test all(prob2.p .≈ [1.0, 0.2, 30.0])
244251

0 commit comments

Comments
 (0)