Skip to content

Commit d5dfc4a

Browse files
feat: add missing_is_symbolic to remake
1 parent 33ed35b commit d5dfc4a

File tree

2 files changed

+83
-60
lines changed

2 files changed

+83
-60
lines changed

src/remake.jl

Lines changed: 49 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -75,17 +75,22 @@ be interpreted as a symbolic map and used as-is for parameters. `use_defaults` a
7575
controlling whether the default values from the system will be used to calculate missing
7676
values in the symbolic map passed to `u0` or `p`. It is only valid when either `u0` or
7777
`p` have been explicitly provided as a symbolic map and the problem has an associated
78-
system.
78+
system. If `missing_is_symbolic`, not providing `u0`/`p` is interpreted as an empty symbolic
79+
map. Otherwise, they retain the existing value in `prob`. If neither `u0` nor `p` is provided,
80+
then they retain the existing values in `prob` regardless of `missing_is_symbolic`.
7981
"""
8082
function remake(prob::AbstractSciMLProblem; u0 = missing,
81-
p = missing, interpret_symbolicmap = true, use_defaults = false, kwargs...)
82-
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults)
83+
p = missing, interpret_symbolicmap = true, use_defaults = false,
84+
missing_is_symbolic = true, kwargs...)
85+
u0, p = updated_u0_p(
86+
prob, u0, p; interpret_symbolicmap, use_defaults, missing_is_symbolic)
8387
_remake_internal(prob; kwargs..., u0, p)
8488
end
8589

8690
function remake(prob::AbstractIntervalNonlinearProblem; p = missing,
87-
interpret_symbolicmap = true, use_defaults = false, kwargs...)
88-
_, p = updated_u0_p(prob, [], p; interpret_symbolicmap, use_defaults)
91+
interpret_symbolicmap = true, use_defaults = false, missing_is_symbolic = true, kwargs...)
92+
_, p = updated_u0_p(
93+
prob, [], p; interpret_symbolicmap, use_defaults, missing_is_symbolic)
8994
_remake_internal(prob; kwargs..., p)
9095
end
9196

@@ -94,8 +99,10 @@ function remake(prob::AbstractNoiseProblem; kwargs...)
9499
end
95100

96101
function remake(
97-
prob::AbstractIntegralProblem; p = missing, interpret_symbolicmap = true, use_defaults = false, kwargs...)
98-
_, p = updated_u0_p(prob, nothing, p; interpret_symbolicmap, use_defaults)
102+
prob::AbstractIntegralProblem; p = missing, interpret_symbolicmap = true,
103+
use_defaults = false, missing_is_symbolic = true, kwargs...)
104+
_, p = updated_u0_p(
105+
prob, nothing, p; interpret_symbolicmap, use_defaults, missing_is_symbolic)
99106
_remake_internal(prob; kwargs..., p)
100107
end
101108

@@ -114,12 +121,14 @@ function remake(prob::ODEProblem; f = missing,
114121
interpret_symbolicmap = true,
115122
build_initializeprob = true,
116123
use_defaults = false,
124+
missing_is_symbolic = true,
117125
_kwargs...)
118126
if tspan === missing
119127
tspan = prob.tspan
120128
end
121129

122-
newu0, newp = updated_u0_p(prob, u0, p, tspan[1]; interpret_symbolicmap, use_defaults)
130+
newu0, newp = updated_u0_p(
131+
prob, u0, p, tspan[1]; interpret_symbolicmap, use_defaults, missing_is_symbolic)
123132

124133
iip = isinplace(prob)
125134

@@ -231,13 +240,15 @@ Remake the given `BVProblem`.
231240
"""
232241
function remake(prob::BVProblem{uType, tType, iip, nlls}; f = missing, bc = missing,
233242
u0 = missing, tspan = missing, p = missing, kwargs = missing, problem_type = missing,
234-
interpret_symbolicmap = true, use_defaults = false, _kwargs...) where {
243+
interpret_symbolicmap = true, use_defaults = false, missing_is_symbolic = true,
244+
_kwargs...) where {
235245
uType, tType, iip, nlls}
236246
if tspan === missing
237247
tspan = prob.tspan
238248
end
239249

240-
u0, p = updated_u0_p(prob, u0, p, tspan[1]; interpret_symbolicmap, use_defaults)
250+
u0, p = updated_u0_p(
251+
prob, u0, p, tspan[1]; interpret_symbolicmap, use_defaults, missing_is_symbolic)
241252

242253
if problem_type === missing
243254
problem_type = prob.problem_type
@@ -296,14 +307,16 @@ function remake(prob::SDEProblem;
296307
noise_rate_prototype = missing,
297308
interpret_symbolicmap = true,
298309
use_defaults = false,
310+
missing_is_symbolic = true,
299311
seed = missing,
300312
kwargs = missing,
301313
_kwargs...)
302314
if tspan === missing
303315
tspan = prob.tspan
304316
end
305317

306-
u0, p = updated_u0_p(prob, u0, p, tspan[1]; interpret_symbolicmap, use_defaults)
318+
u0, p = updated_u0_p(
319+
prob, u0, p, tspan[1]; interpret_symbolicmap, use_defaults, missing_is_symbolic)
307320

308321
if noise === missing
309322
noise = prob.noise
@@ -405,8 +418,10 @@ function remake(prob::OptimizationProblem;
405418
kwargs = missing,
406419
interpret_symbolicmap = true,
407420
use_defaults = false,
421+
missing_is_symbolic = true,
408422
_kwargs...)
409-
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults)
423+
u0, p = updated_u0_p(
424+
prob, u0, p; interpret_symbolicmap, use_defaults, missing_is_symbolic)
410425
if f === missing
411426
f = prob.f
412427
end
@@ -447,7 +462,6 @@ end
447462
problem_type = missing, kwargs = missing, _kwargs...)
448463
449464
Remake the given `NonlinearProblem`.
450-
If `u0` or `p` are given as symbolic maps `ModelingToolkit.jl` has to be loaded.
451465
"""
452466
function remake(prob::NonlinearProblem;
453467
f = missing,
@@ -457,8 +471,10 @@ function remake(prob::NonlinearProblem;
457471
kwargs = missing,
458472
interpret_symbolicmap = true,
459473
use_defaults = false,
474+
missing_is_symbolic = true,
460475
_kwargs...)
461-
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults)
476+
u0, p = updated_u0_p(
477+
prob, u0, p; interpret_symbolicmap, use_defaults, missing_is_symbolic)
462478
if f === missing
463479
f = prob.f
464480
end
@@ -483,8 +499,9 @@ end
483499
Remake the given `NonlinearLeastSquaresProblem`.
484500
"""
485501
function remake(prob::NonlinearLeastSquaresProblem; f = missing, u0 = missing, p = missing,
486-
interpret_symbolicmap = true, use_defaults = false, kwargs = missing, _kwargs...)
487-
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults)
502+
interpret_symbolicmap = true, use_defaults = false, missing_is_symbolic = true, kwargs = missing, _kwargs...)
503+
u0, p = updated_u0_p(
504+
prob, u0, p; interpret_symbolicmap, use_defaults, missing_is_symbolic)
488505

489506
if f === missing
490507
f = prob.f
@@ -520,35 +537,27 @@ anydict(d) = Dict{Any, Any}(d)
520537
anydict() = Dict{Any, Any}()
521538

522539
function _updated_u0_p_internal(
523-
prob, ::Missing, ::Missing, t0; interpret_symbolicmap = true, use_defaults = false)
540+
prob, ::Missing, ::Missing, t0; kwargs...)
524541
return state_values(prob), parameter_values(prob)
525542
end
526543
function _updated_u0_p_internal(
527-
prob, ::Missing, p, t0; interpret_symbolicmap = true, use_defaults = false)
528-
u0 = state_values(prob)
529-
530-
if p isa AbstractArray && isempty(p)
531-
return _updated_u0_p_internal(
532-
prob, u0, parameter_values(prob), t0; interpret_symbolicmap)
533-
end
534-
eltype(p) <: Pair && interpret_symbolicmap || return u0, p
535-
defs = default_values(prob)
536-
p = fill_p(prob, anydict(p); defs, use_defaults)
537-
return _updated_u0_p_symmap(prob, u0, Val(false), p, Val(true), t0)
544+
prob, ::Missing, p, t0; interpret_symbolicmap = true,
545+
use_defaults = false, missing_is_symbolic = true)
546+
return _updated_u0_p_internal(
547+
prob, missing_is_symbolic ? anydict() : state_values(prob),
548+
p, t0; interpret_symbolicmap, use_defaults)
538549
end
539550

540551
function _updated_u0_p_internal(
541-
prob, u0, ::Missing, t0; interpret_symbolicmap = true, use_defaults = false)
542-
p = parameter_values(prob)
543-
544-
eltype(u0) <: Pair || return u0, p
545-
defs = default_values(prob)
546-
u0 = fill_u0(prob, anydict(u0); defs, use_defaults)
547-
return _updated_u0_p_symmap(prob, u0, Val(true), p, Val(false), t0)
552+
prob, u0, ::Missing, t0; interpret_symbolicmap = true,
553+
use_defaults = false, missing_is_symbolic = true)
554+
return _updated_u0_p_internal(
555+
prob, u0, missing_is_symbolic ? anydict() : parameter_values(prob),
556+
t0; interpret_symbolicmap, use_defaults)
548557
end
549558

550559
function _updated_u0_p_internal(
551-
prob, u0, p, t0; interpret_symbolicmap = true, use_defaults = false)
560+
prob, u0, p, t0; interpret_symbolicmap = true, use_defaults = false, kwargs...)
552561
isu0symbolic = eltype(u0) <: Pair
553562
ispsymbolic = eltype(p) <: Pair && interpret_symbolicmap
554563

@@ -727,7 +736,8 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true}, t0)
727736
end
728737

729738
function updated_u0_p(
730-
prob, u0, p, t0 = nothing; interpret_symbolicmap = true, use_defaults = false)
739+
prob, u0, p, t0 = nothing; interpret_symbolicmap = true,
740+
use_defaults = false, missing_is_symbolic = true)
731741
if u0 === missing && p === missing
732742
return state_values(prob), parameter_values(prob)
733743
end
@@ -746,7 +756,8 @@ function updated_u0_p(
746756
return (u0 === missing ? state_values(prob) : u0),
747757
(p === missing ? parameter_values(prob) : p)
748758
end
749-
return _updated_u0_p_internal(prob, u0, p, t0; interpret_symbolicmap, use_defaults)
759+
return _updated_u0_p_internal(
760+
prob, u0, p, t0; interpret_symbolicmap, use_defaults, missing_is_symbolic)
750761
end
751762

752763
# overloaded in MTK to intercept symbolic remake

test/remake_tests.jl

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -121,21 +121,25 @@ for prob in deepcopy(probs)
121121
@test prob2.p == typeof(prob.p)(p)
122122

123123
# respect defaults (:x), fallback to existing value (:z)
124-
prob2 = @inferred baseType remake(prob; u0 = [:y => 0.2], use_defaults = true)
124+
prob2 = @inferred baseType remake(
125+
prob; u0 = [:y => 0.2], use_defaults = true, missing_is_symbolic = false)
125126
@test prob2.u0 [0.1, 0.2, 3.0]
126127
@test prob2.p == typeof(prob.p)(p) # params unaffected
128+
prob2 = @inferred baseType remake(prob; u0 = [:y => 0.2], use_defaults = true)
129+
@test prob2.u0 [0.1, 0.2, 3.0]
130+
@test all(prob2.p .≈ [0.1, 20.0, 30.0])
127131

128132
# override defaults
129133
prob2 = @inferred baseType remake(prob; u0 = [:x => 0.2], use_defaults = true)
130134
@test prob2.u0 [0.2, 2.0, 3.0]
131-
@test prob2.p == typeof(prob.p)(p)
135+
@test all(prob2.p .≈ [0.1, 20.0, 30.0])
132136

133137
prob2 = @inferred baseType remake(prob; p = [:b => 0.2], use_defaults = true)
134-
@test prob2.u0 == u0
138+
@test prob2.u0 [0.1, 2.0, 3.0]
135139
@test all(prob2.p .≈ [0.1, 0.2, 30.0])
136140

137141
prob2 = @inferred baseType remake(prob; p = [:a => 0.2], use_defaults = true)
138-
@test prob2.u0 == u0
142+
@test prob2.u0 [0.1, 2.0, 3.0]
139143
@test all(prob2.p .≈ [0.2, 20.0, 30.0])
140144

141145
empty!(prob.f.sys.defaults)
@@ -152,36 +156,37 @@ for prob in deepcopy(probs)
152156
@test prob2.u0 == u0
153157
@test prob2.p == typeof(prob.p)(p)
154158

155-
prob2 = @inferred baseType remake(prob; u0 = [:x => 0.2])
159+
prob2 = @inferred baseType remake(
160+
prob; u0 = [:x => 0.2], missing_is_symbolic = false)
156161
@test prob2.u0 [0.2, 0.6, 3.0]
157162
@test prob2.p == typeof(prob.p)(p)
158163

159164
# respect numeric defaults (:z)
160165
prob2 = @inferred baseType remake(prob; u0 = [:x => 0.2], use_defaults = true)
161166
@test prob2.u0 [0.2, 0.6, 9.0]
162-
@test prob2.p == typeof(prob.p)(p) # params unaffected
167+
@test all(prob2.p .≈ [10.0, 30.0, 0.9])
163168

164169
# override defaults
165170
prob2 = @inferred baseType remake(prob; u0 = [:y => 0.2])
166171
@test prob2.u0 [1.0, 0.2, 3.0]
167-
@test prob2.p == typeof(prob.p)(p)
172+
@test all(prob2.p .≈ [10.0, 30.0, 30.0])
168173
prob2 = @inferred baseType remake(prob; u0 = [:y => 0.2], use_defaults = true)
169174
@test prob2.u0 [1.0, 0.2, 9.0]
170-
@test prob2.p == typeof(prob.p)(p)
175+
@test all(prob2.p .≈ [10.0, 30.0, 0.9])
171176

172177
prob2 = @inferred baseType remake(prob; p = [:a => 0.2])
173-
@test prob2.u0 == u0
178+
@test prob2.u0 [1.0, 3.0, 3.0]
174179
@test all(prob2.p .≈ [0.2, 0.6, 30.0])
175180

176181
prob2 = @inferred baseType remake(prob; p = [:a => 0.2], use_defaults = true)
177-
@test prob2.u0 == u0
182+
@test prob2.u0 [1.0, 3.0, 9.0]
178183
@test all(prob2.p .≈ [0.2, 0.6, 0.9])
179184

180185
prob2 = @inferred baseType remake(prob; p = [:b => 0.2])
181-
@test prob2.u0 == u0
186+
@test prob2.u0 [1.0, 3.0, 3.0]
182187
@test all(prob2.p .≈ [10.0, 0.2, 30.0])
183188
prob2 = @inferred baseType remake(prob; p = [:b => 0.2], use_defaults = true)
184-
@test prob2.u0 == u0
189+
@test prob2.u0 [1.0, 3.0, 9.0]
185190
@test all(prob2.p .≈ [10.0, 0.2, 0.9])
186191

187192
empty!(prob.f.sys.defaults)
@@ -198,47 +203,54 @@ for prob in deepcopy(probs)
198203
@test prob2.u0 == u0
199204
@test prob2.p == typeof(prob.p)(p)
200205

201-
# Dependency ignored since `p` was not changed
202-
prob2 = @inferred baseType remake(prob; u0 = [:x => 0.2])
206+
# Dependency ignored since `p` was not changed and `missing_is_symbolic = false`
207+
prob2 = @inferred baseType remake(
208+
prob; u0 = [:x => 0.2], missing_is_symbolic = false)
203209
@test prob2.u0 [0.2, 30.0, 3.0]
204210
@test prob2.p == typeof(prob.p)(p)
205211

206-
# need to pass empty `Dict()` to prevent defaulting to existing values
212+
# with `missing_is_symbolic = true`
213+
prob2 = @inferred baseType remake(
214+
prob; u0 = [:x => 0.2])
215+
@test prob2.u0 [0.2, 30.0, 3.0]
216+
@test all(prob2.p .≈ [10.0, 0.6, 30.0])
217+
218+
# above is identical to this
207219
prob2 = @inferred baseType remake(
208220
prob; u0 = [:x => 0.2], p = Dict())
209221
@test prob2.u0 [0.2, 30.0, 3.0]
210222
@test all(prob2.p .≈ [10.0, 0.6, 30.0])
211223

212224
prob2 = @inferred baseType remake(
213-
prob; u0 = [:x => 0.2], p = Dict(), use_defaults = true)
225+
prob; u0 = [:x => 0.2], use_defaults = true)
214226
@test prob2.u0 [0.2, 3.0, 3.0]
215227
@test all(prob2.p .≈ [1.0, 0.6, 30.0])
216228

217229
# override defaults
218230
prob2 = @inferred baseType remake(
219-
prob; u0 = [:y => 0.2], p = Dict())
231+
prob; u0 = [:y => 0.2])
220232
@test prob2.u0 [1.0, 0.2, 3.0]
221233
@test all(prob2.p .≈ [10.0, 3.0, 30.0])
222234
prob2 = @inferred baseType remake(
223-
prob; u0 = [:y => 0.2], p = Dict(), use_defaults = true)
235+
prob; u0 = [:y => 0.2], use_defaults = true)
224236
@test prob2.u0 [0.1, 0.2, 3.0]
225237
@test all(prob2.p .≈ [1.0, 0.3, 30.0])
226238

227239
prob2 = @inferred baseType remake(
228-
prob; p = [:a => 0.2], u0 = Dict())
240+
prob; p = [:a => 0.2])
229241
@test prob2.u0 [1.0, 0.6, 3.0]
230242
@test all(prob2.p .≈ [0.2, 3.0, 30.0])
231243
prob2 = @inferred baseType remake(
232-
prob; p = [:a => 0.2], u0 = Dict(), use_defaults = true)
244+
prob; p = [:a => 0.2], use_defaults = true)
233245
@test prob2.u0 [0.1, 0.6, 3.0]
234246
@test all(prob2.p .≈ [0.2, 0.3, 30.0])
235247

236248
prob2 = @inferred baseType remake(
237-
prob; p = [:b => 0.2], u0 = Dict())
249+
prob; p = [:b => 0.2])
238250
@test prob2.u0 [1.0, 30.0, 3.0]
239251
@test all(prob2.p .≈ [10.0, 0.2, 30.0])
240252
prob2 = @inferred baseType remake(
241-
prob; p = [:b => 0.2], u0 = Dict(), use_defaults = true)
253+
prob; p = [:b => 0.2], use_defaults = true)
242254
@test prob2.u0 [0.1, 3.0, 3.0]
243255
@test all(prob2.p .≈ [1.0, 0.2, 30.0])
244256

0 commit comments

Comments
 (0)