Skip to content

Commit 2682251

Browse files
refactor: support new remake_buffer signature
1 parent 1e4446a commit 2682251

File tree

1 file changed

+28
-27
lines changed

1 file changed

+28
-27
lines changed

src/systems/parameter_buffer.jl

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -505,46 +505,47 @@ function indp_to_system(indp)
505505
return indp
506506
end
507507

508-
function SymbolicIndexingInterface.remake_buffer(indp, oldbuf::MTKParameters, vals::Dict)
508+
function validate_and_update(ic::IndexCache, buffer::MTKParameters, sym, idx, val)
509+
if sym !== nothing
510+
validate_parameter_type(ic, sym, idx, val)
511+
end
512+
_set_parameter_unchecked!(buffer, val, idx; update_dependent = false)
513+
end
514+
515+
function SymbolicIndexingInterface.remake_buffer(indp, oldbuf::MTKParameters, idxs, vals)
509516
newbuf = @set oldbuf.tunable = Vector{Any}(undef, length(oldbuf.tunable))
510517
@set! newbuf.discrete = Tuple(similar(buf, Any) for buf in newbuf.discrete)
511518
@set! newbuf.constant = Tuple(Vector{Any}(undef, length(buf))
512519
for buf in newbuf.constant)
513520
@set! newbuf.nonnumeric = Tuple(Vector{Any}(undef, length(buf))
514521
for buf in newbuf.nonnumeric)
515522

516-
syms = collect(keys(vals))
517-
vals = Dict{Any, Any}(vals)
518-
for sym in syms
519-
symbolic_type(sym) == ArraySymbolic() || continue
520-
is_parameter(indp, sym) && continue
521-
stype = symtype(unwrap(sym))
522-
stype <: AbstractArray || continue
523-
Symbolics.shape(sym) == Symbolics.Unknown() && continue
524-
for i in eachindex(sym)
525-
vals[sym[i]] = vals[sym][i]
526-
end
527-
end
528-
529523
# If the parameter buffer is an `MTKParameters` object, `indp` must eventually drill
530524
# down to an `AbstractSystem` using `symbolic_container`. We leverage this to get
531525
# the index cache.
532526
ic = get_index_cache(indp_to_system(indp))
533-
for (p, val) in vals
534-
idx = parameter_index(indp, p)
535-
if idx !== nothing
536-
validate_parameter_type(ic, p, idx, val)
537-
_set_parameter_unchecked!(
538-
newbuf, val, idx; update_dependent = false)
539-
elseif symbolic_type(p) == ArraySymbolic()
540-
for (i, j) in zip(eachindex(p), eachindex(val))
541-
pi = p[i]
542-
idx = parameter_index(indp, pi)
543-
validate_parameter_type(ic, pi, idx, val[j])
544-
_set_parameter_unchecked!(
545-
newbuf, val[j], idx; update_dependent = false)
527+
for (idx, val) in zip(idxs, vals)
528+
if val === missing
529+
val = get_temporary_value(idx)
530+
end
531+
if !(idx isa ParameterIndex)
532+
p = unwrap(idx)
533+
idx = parameter_index(ic, p)
534+
stype = symtype(p)
535+
if idx === nothing && symbolic_type(p) == ArraySymbolic() &&
536+
stype <: AbstractArray && Symbolics.shape(p) != Symbolics.Unknown()
537+
for i in eachindex(p)
538+
sym = p[i]
539+
subidx = parameter_index(ic, sym)
540+
subval = val[i]
541+
validate_and_update(ic, newbuf, sym, subidx, subval)
542+
end
543+
continue
546544
end
545+
else
546+
p = nothing
547547
end
548+
validate_and_update(ic, newbuf, p, idx, val)
548549
end
549550

550551
@set! newbuf.tunable = narrow_buffer_type_and_fallback_undefs(

0 commit comments

Comments
 (0)