Skip to content

Commit 2ae7183

Browse files
Merge pull request #2897 from AayushSabharwal/as/fix-remake
fix: fix remaking scalarized array parameters with non-scalarized symbolic map
2 parents 2c2e914 + 4ccd1d9 commit 2ae7183

File tree

2 files changed

+43
-3
lines changed

2 files changed

+43
-3
lines changed

src/systems/parameter_buffer.jl

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -506,15 +506,38 @@ function SymbolicIndexingInterface.remake_buffer(indp, oldbuf::MTKParameters, va
506506
@set! newbuf.nonnumeric = Tuple(Vector{Any}(undef, length(buf))
507507
for buf in newbuf.nonnumeric)
508508

509+
syms = collect(keys(vals))
510+
vals = Dict{Any, Any}(vals)
511+
for sym in syms
512+
symbolic_type(sym) == ArraySymbolic() || continue
513+
is_parameter(indp, sym) && continue
514+
stype = symtype(unwrap(sym))
515+
stype <: AbstractArray || continue
516+
Symbolics.shape(sym) == Symbolics.Unknown() && continue
517+
for i in eachindex(sym)
518+
vals[sym[i]] = vals[sym][i]
519+
end
520+
end
521+
509522
# If the parameter buffer is an `MTKParameters` object, `indp` must eventually drill
510523
# down to an `AbstractSystem` using `symbolic_container`. We leverage this to get
511524
# the index cache.
512525
ic = get_index_cache(indp_to_system(indp))
513526
for (p, val) in vals
514527
idx = parameter_index(indp, p)
515-
validate_parameter_type(ic, p, idx, val)
516-
_set_parameter_unchecked!(
517-
newbuf, val, idx; update_dependent = false)
528+
if idx !== nothing
529+
validate_parameter_type(ic, p, idx, val)
530+
_set_parameter_unchecked!(
531+
newbuf, val, idx; update_dependent = false)
532+
elseif symbolic_type(p) == ArraySymbolic()
533+
for (i, j) in zip(eachindex(p), eachindex(val))
534+
pi = p[i]
535+
idx = parameter_index(indp, pi)
536+
validate_parameter_type(ic, pi, idx, val[j])
537+
_set_parameter_unchecked!(
538+
newbuf, val[j], idx; update_dependent = false)
539+
end
540+
end
518541
end
519542

520543
@set! newbuf.tunable = narrow_buffer_type_and_fallback_undefs.(

test/mtkparameters.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,3 +290,20 @@ end
290290
sys = complete(sys)
291291
@test_throws ["Could not evaluate", "b", "Missing", "2c"] MTKParameters(sys, [a => 1.0])
292292
end
293+
294+
@testset "Issue#3804" begin
295+
@parameters k[1:4]
296+
@variables (V(t))[1:2]
297+
eqs = [
298+
D(V[1]) ~ k[1] - k[2] * V[1],
299+
D(V[2]) ~ k[3] - k[4] * V[2]
300+
]
301+
@mtkbuild osys_scal = ODESystem(eqs, t, [V[1], V[2]], [k[1], k[2], k[3], k[4]])
302+
303+
u0 = [V => [10.0, 20.0]]
304+
ps_vec = [k => [2.0, 3.0, 4.0, 5.0]]
305+
ps_scal = [k[1] => 1.0, k[2] => 2.0, k[3] => 3.0, k[4] => 4.0]
306+
oprob_scal_scal = ODEProblem(osys_scal, u0, 1.0, ps_scal)
307+
newoprob = remake(oprob_scal_scal; p = ps_vec)
308+
@test newoprob.ps[k] == [2.0, 3.0, 4.0, 5.0]
309+
end

0 commit comments

Comments
 (0)