Skip to content

Commit c1b5d83

Browse files
Merge pull request #2633 from AayushSabharwal/as/partial-remake-buffer
feat: support partial updates in `remake_buffer`
2 parents e71c417 + df0ce79 commit c1b5d83

File tree

2 files changed

+34
-12
lines changed

2 files changed

+34
-12
lines changed

src/systems/parameter_buffer.jl

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -324,29 +324,42 @@ function _set_parameter_unchecked!(
324324
p.dependent_update_iip(ArrayPartition(p.dependent), p...)
325325
end
326326

327-
function narrow_buffer_type(buffer::Vector)
327+
function narrow_buffer_type_and_fallback_undefs(oldbuf::Vector, newbuf::Vector)
328328
type = Union{}
329-
for x in buffer
330-
type = Union{type, typeof(x)}
329+
for i in eachindex(newbuf)
330+
isassigned(newbuf, i) || continue
331+
type = promote_type(type, typeof(newbuf[i]))
331332
end
332-
return convert(Vector{type}, buffer)
333+
for i in eachindex(newbuf)
334+
isassigned(newbuf, i) && continue
335+
newbuf[i] = convert(type, oldbuf[i])
336+
end
337+
return convert(Vector{type}, newbuf)
333338
end
334339

335340
function SymbolicIndexingInterface.remake_buffer(sys, oldbuf::MTKParameters, vals::Dict)
336-
newbuf = @set oldbuf.tunable = similar.(oldbuf.tunable, Any)
337-
@set! newbuf.discrete = similar.(newbuf.discrete, Any)
338-
@set! newbuf.constant = similar.(newbuf.constant, Any)
339-
@set! newbuf.nonnumeric = similar.(newbuf.nonnumeric, Any)
341+
newbuf = @set oldbuf.tunable = Tuple(Vector{Any}(undef, length(buf))
342+
for buf in oldbuf.tunable)
343+
@set! newbuf.discrete = Tuple(Vector{Any}(undef, length(buf))
344+
for buf in newbuf.discrete)
345+
@set! newbuf.constant = Tuple(Vector{Any}(undef, length(buf))
346+
for buf in newbuf.constant)
347+
@set! newbuf.nonnumeric = Tuple(Vector{Any}(undef, length(buf))
348+
for buf in newbuf.nonnumeric)
340349

341350
for (p, val) in vals
342351
_set_parameter_unchecked!(
343352
newbuf, val, parameter_index(sys, p); update_dependent = false)
344353
end
345354

346-
@set! newbuf.tunable = narrow_buffer_type.(newbuf.tunable)
347-
@set! newbuf.discrete = narrow_buffer_type.(newbuf.discrete)
348-
@set! newbuf.constant = narrow_buffer_type.(newbuf.constant)
349-
@set! newbuf.nonnumeric = narrow_buffer_type.(newbuf.nonnumeric)
355+
@set! newbuf.tunable = narrow_buffer_type_and_fallback_undefs.(
356+
oldbuf.tunable, newbuf.tunable)
357+
@set! newbuf.discrete = narrow_buffer_type_and_fallback_undefs.(
358+
oldbuf.discrete, newbuf.discrete)
359+
@set! newbuf.constant = narrow_buffer_type_and_fallback_undefs.(
360+
oldbuf.constant, newbuf.constant)
361+
@set! newbuf.nonnumeric = narrow_buffer_type_and_fallback_undefs.(
362+
oldbuf.nonnumeric, newbuf.nonnumeric)
350363
if newbuf.dependent_update_oop !== nothing
351364
@set! newbuf.dependent = newbuf.dependent_update_oop(newbuf...)
352365
end

test/mtkparameters.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,3 +112,12 @@ u0 = [X => 1.0]
112112
ps = [p => [2.0, 0.1]]
113113
p = MTKParameters(osys, ps, u0)
114114
@test p.tunable[1] == [2.0, 0.1]
115+
116+
# Ensure partial update promotes the buffer
117+
@parameters p q r
118+
@named sys = ODESystem(Equation[], t, [], [p, q, r])
119+
sys = complete(sys)
120+
ps = MTKParameters(sys, [p => 1.0, q => 2.0, r => 3.0])
121+
newps = remake_buffer(sys, ps, Dict(p => 1.0f0))
122+
@test newps.tunable[1] isa Vector{Float32}
123+
@test newps.tunable[1] == [1.0f0, 2.0f0, 3.0f0]

0 commit comments

Comments
 (0)