Skip to content

Commit 8f75e90

Browse files
Merge pull request #3020 from AayushSabharwal/as/fix-remake-buffer
feat: update to new `remake_buffer` signature
2 parents f1b8e72 + ee8a5e9 commit 8f75e90

File tree

4 files changed

+130
-112
lines changed

4 files changed

+130
-112
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,15 +109,15 @@ PrecompileTools = "1"
109109
RecursiveArrayTools = "3.26"
110110
Reexport = "0.2, 1"
111111
RuntimeGeneratedFunctions = "0.5.9"
112-
SciMLBase = "2.46"
112+
SciMLBase = "2.52.1"
113113
SciMLStructures = "1.0"
114114
Serialization = "1"
115115
Setfield = "0.7, 0.8, 1"
116116
SimpleNonlinearSolve = "0.1.0, 1"
117117
SparseArrays = "1"
118118
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
119119
StaticArrays = "0.10, 0.11, 0.12, 1.0"
120-
SymbolicIndexingInterface = "0.3.28"
120+
SymbolicIndexingInterface = "0.3.29"
121121
SymbolicUtils = "3.2"
122122
Symbolics = "6.3"
123123
URIs = "1"

src/systems/index_cache.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ struct ParameterIndex{P, I}
1818
end
1919

2020
ParameterIndex(portion, idx) = ParameterIndex(portion, idx, false)
21+
ParameterIndex(p::ParameterIndex) = ParameterIndex(p.portion, p.idx, false)
2122

2223
struct DiscreteIndex
2324
# of all buffers corresponding to types, which one
@@ -318,7 +319,8 @@ function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym)
318319
sym = get(ic.symbol_to_variable, sym, nothing)
319320
sym === nothing && return nothing
320321
end
321-
validate_size = Symbolics.isarraysymbolic(sym) &&
322+
sym = unwrap(sym)
323+
validate_size = Symbolics.isarraysymbolic(sym) && symtype(sym) <: AbstractArray &&
322324
Symbolics.shape(sym) !== Symbolics.Unknown()
323325
return if (idx = check_index_map(ic.tunable_idx, sym)) !== nothing
324326
ParameterIndex(SciMLStructures.Tunable(), idx, validate_size)
@@ -459,3 +461,19 @@ function iterated_buffer_index(ic::IndexCache, ind::ParameterIndex)
459461
end
460462
error("Unhandled portion $(ind.portion)")
461463
end
464+
465+
function get_buffer_template(ic::IndexCache, pidx::ParameterIndex)
466+
(; portion, idx) = pidx
467+
468+
if portion isa SciMLStructures.Tunable
469+
return ic.tunable_buffer_size
470+
elseif portion isa SciMLStructures.Discrete
471+
return ic.discrete_buffer_sizes[idx[1]][1]
472+
elseif portion isa SciMLStructures.Constants
473+
return ic.constant_buffer_sizes[idx[1]]
474+
elseif portion isa Nonnumeric
475+
return ic.nonnumeric_buffer_sizes[idx[1]]
476+
else
477+
error("Unhandled portion $portion")
478+
end
479+
end

src/systems/parameter_buffer.jl

Lines changed: 88 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -396,37 +396,6 @@ function SymbolicIndexingInterface.set_parameter!(
396396
return nothing
397397
end
398398

399-
function _set_parameter_unchecked!(
400-
p::MTKParameters, val, idx::ParameterIndex; update_dependent = true)
401-
@unpack portion, idx = idx
402-
if portion isa SciMLStructures.Tunable
403-
p.tunable[idx] = val
404-
else
405-
i, j, k... = idx
406-
if portion isa SciMLStructures.Discrete
407-
if isempty(k)
408-
p.discrete[i][j] = val
409-
else
410-
p.discrete[i][j][k...] = val
411-
end
412-
elseif portion isa SciMLStructures.Constants
413-
if isempty(k)
414-
p.constant[i][j] = val
415-
else
416-
p.constant[i][j][k...] = val
417-
end
418-
elseif portion === NONNUMERIC_PORTION
419-
if isempty(k)
420-
p.nonnumeric[i][j] = val
421-
else
422-
p.nonnumeric[i][j][k...] = val
423-
end
424-
else
425-
error("Unhandled portion $portion")
426-
end
427-
end
428-
end
429-
430399
function narrow_buffer_type_and_fallback_undefs(
431400
oldbuf::AbstractVector, newbuf::AbstractVector)
432401
type = Union{}
@@ -448,31 +417,42 @@ function narrow_buffer_type_and_fallback_undefs(
448417
return newerbuf
449418
end
450419

451-
function validate_parameter_type(ic::IndexCache, p, index, val)
420+
function validate_parameter_type(ic::IndexCache, p, idx::ParameterIndex, val)
452421
p = unwrap(p)
453422
if p isa Symbol
454423
p = get(ic.symbol_to_variable, p, nothing)
455-
if p === nothing
456-
@warn "No matching variable found for `Symbol` $p, skipping type validation."
457-
return nothing
458-
end
424+
p === nothing && return validate_parameter_type(ic, idx, val)
425+
end
426+
stype = symtype(p)
427+
sz = if stype <: AbstractArray
428+
Symbolics.shape(p) == Symbolics.Unknown() ? Symbolics.Unknown() : size(p)
429+
elseif stype <: Number
430+
size(p)
431+
else
432+
Symbolics.Unknown()
459433
end
434+
validate_parameter_type(ic, stype, sz, p, idx, val)
435+
end
436+
437+
function validate_parameter_type(ic::IndexCache, idx::ParameterIndex, val)
438+
validate_parameter_type(
439+
ic, get_buffer_template(ic, idx).type, Symbolics.Unknown(), nothing, idx, val)
440+
end
441+
442+
function validate_parameter_type(ic::IndexCache, stype, sz, sym, index, val)
460443
(; portion) = index
461444
# Nonnumeric parameters have to match the type
462445
if portion === NONNUMERIC_PORTION
463-
stype = symtype(p)
464446
val isa stype && return nothing
465-
throw(ParameterTypeException(:validate_parameter_type, p, stype, val))
447+
throw(ParameterTypeException(:validate_parameter_type, sym, stype, val))
466448
end
467-
stype = symtype(p)
468449
# Array parameters need array values...
469450
if stype <: AbstractArray && !isa(val, AbstractArray)
470-
throw(ParameterTypeException(:validate_parameter_type, p, stype, val))
451+
throw(ParameterTypeException(:validate_parameter_type, sym, stype, val))
471452
end
472453
# ... and must match sizes
473-
if stype <: AbstractArray && Symbolics.shape(p) !== Symbolics.Unknown() &&
474-
size(val) != size(p)
475-
throw(InvalidParameterSizeException(p, val))
454+
if stype <: AbstractArray && sz != Symbolics.Unknown() && size(val) != sz
455+
throw(InvalidParameterSizeException(sym, val))
476456
end
477457
# Early exit
478458
val isa stype && return nothing
@@ -485,15 +465,15 @@ function validate_parameter_type(ic::IndexCache, p, index, val)
485465
# This is for duals and other complicated number types
486466
etype = SciMLBase.parameterless_type(etype)
487467
eltype(val) <: etype || throw(ParameterTypeException(
488-
:validate_parameter_type, p, AbstractArray{etype}, val))
468+
:validate_parameter_type, sym, AbstractArray{etype}, val))
489469
else
490470
# Real check
491471
if stype <: Real
492472
stype = Real
493473
end
494474
stype = SciMLBase.parameterless_type(stype)
495475
val isa stype ||
496-
throw(ParameterTypeException(:validate_parameter_type, p, stype, val))
476+
throw(ParameterTypeException(:validate_parameter_type, sym, stype, val))
497477
end
498478
end
499479

@@ -504,45 +484,69 @@ function indp_to_system(indp)
504484
return indp
505485
end
506486

507-
function SymbolicIndexingInterface.remake_buffer(indp, oldbuf::MTKParameters, vals::Dict)
508-
newbuf = @set oldbuf.tunable = Vector{Any}(undef, length(oldbuf.tunable))
487+
function SymbolicIndexingInterface.remake_buffer(indp, oldbuf::MTKParameters, idxs, vals)
488+
newbuf = @set oldbuf.tunable = similar(oldbuf.tunable, Any)
509489
@set! newbuf.discrete = Tuple(similar(buf, Any) for buf in newbuf.discrete)
510-
@set! newbuf.constant = Tuple(Vector{Any}(undef, length(buf))
511-
for buf in newbuf.constant)
512-
@set! newbuf.nonnumeric = Tuple(Vector{Any}(undef, length(buf))
513-
for buf in newbuf.nonnumeric)
514-
515-
syms = collect(keys(vals))
516-
vals = Dict{Any, Any}(vals)
517-
for sym in syms
518-
symbolic_type(sym) == ArraySymbolic() || continue
519-
is_parameter(indp, sym) && continue
520-
stype = symtype(unwrap(sym))
521-
stype <: AbstractArray || continue
522-
Symbolics.shape(sym) == Symbolics.Unknown() && continue
523-
for i in eachindex(sym)
524-
vals[sym[i]] = vals[sym][i]
490+
@set! newbuf.constant = Tuple(similar(buf, Any) for buf in newbuf.constant)
491+
@set! newbuf.nonnumeric = Tuple(similar(buf, Any) for buf in newbuf.nonnumeric)
492+
493+
function handle_parameter(ic, sym, idx, val)
494+
if sym === nothing
495+
validate_parameter_type(ic, idx, val)
496+
else
497+
validate_parameter_type(ic, sym, idx, val)
525498
end
499+
# `ParameterIndex(idx)` turns off size validation since it relies on there
500+
# being an existing value
501+
set_parameter!(newbuf, val, ParameterIndex(idx))
526502
end
527503

504+
handled_idxs = Set{ParameterIndex}()
528505
# If the parameter buffer is an `MTKParameters` object, `indp` must eventually drill
529506
# down to an `AbstractSystem` using `symbolic_container`. We leverage this to get
530507
# the index cache.
531508
ic = get_index_cache(indp_to_system(indp))
532-
for (p, val) in vals
533-
idx = parameter_index(indp, p)
534-
if idx !== nothing
535-
validate_parameter_type(ic, p, idx, val)
536-
_set_parameter_unchecked!(
537-
newbuf, val, idx; update_dependent = false)
538-
elseif symbolic_type(p) == ArraySymbolic()
539-
for (i, j) in zip(eachindex(p), eachindex(val))
540-
pi = p[i]
541-
idx = parameter_index(indp, pi)
542-
validate_parameter_type(ic, pi, idx, val[j])
543-
_set_parameter_unchecked!(
544-
newbuf, val[j], idx; update_dependent = false)
509+
for (idx, val) in zip(idxs, vals)
510+
sym = nothing
511+
if symbolic_type(idx) == ScalarSymbolic()
512+
sym = idx
513+
idx = parameter_index(ic, sym)
514+
if idx === nothing
515+
@warn "Symbolic variable $sym is not a (non-dependent) parameter in the system"
516+
continue
517+
end
518+
idx in handled_idxs && continue
519+
handle_parameter(ic, sym, idx, val)
520+
push!(handled_idxs, idx)
521+
elseif symbolic_type(idx) == ArraySymbolic()
522+
sym = idx
523+
idx = parameter_index(ic, sym)
524+
if idx === nothing
525+
Symbolics.shape(sym) == Symbolics.Unknown() &&
526+
throw(ParameterNotInSystem(sym))
527+
size(sym) == size(val) || throw(InvalidParameterSizeException(sym, val))
528+
529+
for (i, vali) in zip(eachindex(sym), eachindex(val))
530+
idx = parameter_index(ic, sym[i])
531+
if idx === nothing
532+
@warn "Symbolic variable $sym is not a (non-dependent) parameter in the system"
533+
continue
534+
end
535+
# Intentionally don't check handled_idxs here because array variables always take priority
536+
# See Issue#2804
537+
handle_parameter(ic, sym[i], idx, val[vali])
538+
push!(handled_idxs, idx)
539+
end
540+
else
541+
idx in handled_idxs && continue
542+
handle_parameter(ic, sym, idx, val)
543+
push!(handled_idxs, idx)
545544
end
545+
else # NotSymbolic
546+
if !(idx isa ParameterIndex)
547+
throw(ArgumentError("Expected index for parameter to be a symbolic variable or `ParameterIndex`, got $idx"))
548+
end
549+
handle_parameter(ic, nothing, idx, val)
546550
end
547551
end
548552

@@ -688,7 +692,7 @@ function jacobian_wrt_vars(pf::F, p::MTKParameters, input_idxs, chunk::C) where
688692

689693
function (p_small_inner)
690694
for (i, val) in zip(input_idxs, p_small_inner)
691-
_set_parameter_unchecked!(p_big, val, i)
695+
set_parameter!(p_big, val, i)
692696
end
693697
return if pf isa SciMLBase.ParamJacobianWrapper
694698
buffer = Array{dualtype}(undef, size(pf.u))
@@ -735,3 +739,11 @@ end
735739
function ParameterTypeException(func, param, expected, val)
736740
TypeError(func, "Parameter $param", expected, val)
737741
end
742+
743+
struct ParameterNotInSystem <: Exception
744+
p::Any
745+
end
746+
747+
function Base.showerror(io::IO, e::ParameterNotInSystem)
748+
println(io, "Symbolic variable $(e.p) is not a parameter in the system")
749+
end

test/mtkparameters.jl

Lines changed: 21 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ using JET
99

1010
@parameters a b c(t) d::Integer e[1:3] f[1:3, 1:3]::Int g::Vector{AbstractFloat} h::String
1111
@named sys = ODESystem(
12-
Equation[], t, [], [a, c, d, e, f, g, h], parameter_dependencies = [b => 2a],
12+
Equation[], t, [], [a, c, d, e, f, g, h], parameter_dependencies = [b ~ 2a],
1313
continuous_events = [[a ~ 0] => [c ~ 0]], defaults = Dict(a => 0.0))
1414
sys = complete(sys)
1515

@@ -72,10 +72,12 @@ setp(sys, g)(ps, ones(100)) # with non-fixed-length array
7272
setp(sys, h)(ps, "bar") # with a non-numeric
7373
@test getp(sys, h)(ps) == "bar"
7474

75-
newps = remake_buffer(sys,
76-
ps,
77-
Dict(a => 1.0f0, b => 5.0f0, c => 2.0, d => 0x5, e => Float32[0.4, 0.5, 0.6],
78-
f => 3ones(UInt, 3, 3), g => ones(Float32, 4), h => "bar"))
75+
varmap = Dict(a => 1.0f0, b => 5.0f0, c => 2.0, d => 0x5, e => Float32[0.4, 0.5, 0.6],
76+
f => 3ones(UInt, 3, 3), g => ones(Float32, 4), h => "bar")
77+
@test_deprecated remake_buffer(sys, ps, varmap)
78+
@test_warn ["Symbolic variable b", "non-dependent", "parameter"] remake_buffer(
79+
sys, ps, keys(varmap), values(varmap))
80+
newps = remake_buffer(sys, ps, keys(varmap), values(varmap))
7981

8082
for fname in (:tunable, :discrete, :constant)
8183
# ensure same number of sub-buffers
@@ -92,8 +94,7 @@ end
9294
ps = MTKParameters(sys, ivs)
9395
function loss(value, sys, ps)
9496
@test value isa ForwardDiff.Dual
95-
vals = merge(Dict(parameters(sys) .=> getp(sys, parameters(sys))(ps)), Dict(a => value))
96-
ps = remake_buffer(sys, ps, vals)
97+
ps = remake_buffer(sys, ps, (a,), (value,))
9798
getp(sys, a)(ps) + getp(sys, b)(ps)
9899
end
99100

@@ -115,7 +116,7 @@ p = MTKParameters(osys, ps, u0)
115116
@named sys = ODESystem(Equation[], t, [], [p, q, r])
116117
sys = complete(sys)
117118
ps = MTKParameters(sys, [p => 1.0, q => 2.0, r => 3.0])
118-
newps = remake_buffer(sys, ps, Dict(p => 1.0f0))
119+
newps = remake_buffer(sys, ps, (p,), (1.0f0,))
119120
@test newps.tunable isa Vector{Float32}
120121
@test newps.tunable == [1.0f0, 2.0f0, 3.0f0]
121122

@@ -227,19 +228,6 @@ end
227228

228229
@test_nowarn ForwardDiff.gradient(loss, collect(tunables))
229230

230-
# Ensure dependent parameters are `Tuple{...}` and not `ArrayPartition` when using
231-
# `remake_buffer`.
232-
@parameters p1 p2 p3[1:2] p4[1:2]
233-
@named sys = ODESystem(
234-
Equation[], t, [], [p1, p2, p3, p4]; parameter_dependencies = [p2 => 2p1, p4 => 3p3])
235-
sys = complete(sys)
236-
ps = MTKParameters(sys, [p1 => 1.0, p3 => [2.0, 3.0]])
237-
@test getp(sys, p2)(ps) == 2.0
238-
@test getp(sys, p4)(ps) == [6.0, 9.0]
239-
240-
newps = remake_buffer(
241-
sys, ps, Dict(p1 => ForwardDiff.Dual(2.0), p3 => ForwardDiff.Dual.([3.0, 4.0])))
242-
243231
VDual = Vector{<:ForwardDiff.Dual}
244232
VVDual = Vector{<:Vector{<:ForwardDiff.Dual}}
245233

@@ -263,26 +251,26 @@ VVDual = Vector{<:Vector{<:ForwardDiff.Dual}}
263251

264252
# Same flexibility is afforded to `b::Int` to allow for ForwardDiff
265253
for sym in [a, b]
266-
@test_nowarn remake_buffer(sys, ps, Dict(sym => 1))
267-
newps = @test_nowarn remake_buffer(sys, ps, Dict(sym => 1.0f0)) # Can change type if it's numeric
254+
@test_nowarn remake_buffer(sys, ps, (sym,), (1,))
255+
newps = @test_nowarn remake_buffer(sys, ps, (sym,), (1.0f0,)) # Can change type if it's numeric
268256
@test getp(sys, sym)(newps) isa Float32
269-
newps = @test_nowarn remake_buffer(sys, ps, Dict(sym => ForwardDiff.Dual(1.0)))
257+
newps = @test_nowarn remake_buffer(sys, ps, sym, ForwardDiff.Dual(1.0))
270258
@test getp(sys, sym)(newps) isa ForwardDiff.Dual
271-
@test_throws TypeError remake_buffer(sys, ps, Dict(sym => :a)) # still has to be numeric
259+
@test_throws TypeError remake_buffer(sys, ps, (sym,), (:a,)) # still has to be numeric
272260
end
273261

274-
newps = @test_nowarn remake_buffer(sys, ps, Dict(c => view(1.0:4.0, 2:4))) # can change type of array
262+
newps = @test_nowarn remake_buffer(sys, ps, (c,), (view(1.0:4.0, 2:4),)) # can change type of array
275263
@test getp(sys, c)(newps) == 2.0:4.0
276264
@test parameter_values(newps, parameter_index(sys, c)) [2.0, 3.0, 4.0]
277-
@test_throws TypeError remake_buffer(sys, ps, Dict(c => [:a, :b, :c])) # can't arbitrarily change eltype
278-
@test_throws TypeError remake_buffer(sys, ps, Dict(c => :a)) # can't arbitrarily change type
265+
@test_throws TypeError remake_buffer(sys, ps, (c,), ([:a, :b, :c],)) # can't arbitrarily change eltype
266+
@test_throws TypeError remake_buffer(sys, ps, (c,), (:a,)) # can't arbitrarily change type
279267

280-
newps = @test_nowarn remake_buffer(sys, ps, Dict(d => ForwardDiff.Dual.(ones(2, 2)))) # can change eltype
281-
@test_throws TypeError remake_buffer(sys, ps, Dict(d => [:a :b; :c :d])) # eltype still has to be numeric
268+
newps = @test_nowarn remake_buffer(sys, ps, (d,), (ForwardDiff.Dual.(ones(2, 2)),)) # can change eltype
269+
@test_throws TypeError remake_buffer(sys, ps, (d,), ([:a :b; :c :d],)) # eltype still has to be numeric
282270
@test getp(sys, d)(newps) isa Matrix{<:ForwardDiff.Dual}
283271

284-
@test_throws TypeError remake_buffer(sys, ps, Dict(e => Foo(2.0))) # need exact same type for nonnumeric
285-
@test_nowarn remake_buffer(sys, ps, Dict(f => Foo(:a)))
272+
@test_throws TypeError remake_buffer(sys, ps, (e,), (Foo(2.0),)) # need exact same type for nonnumeric
273+
@test_nowarn remake_buffer(sys, ps, (f,), (Foo(:a),))
286274
end
287275

288276
@testset "Error on missing parameter defaults" begin
@@ -292,7 +280,7 @@ end
292280
@test_throws ["Could not evaluate", "b", "Missing", "2c"] MTKParameters(sys, [a => 1.0])
293281
end
294282

295-
@testset "Issue#3804" begin
283+
@testset "Issue#2804" begin
296284
@parameters k[1:4]
297285
@variables (V(t))[1:2]
298286
eqs = [

0 commit comments

Comments
 (0)