@@ -396,37 +396,6 @@ function SymbolicIndexingInterface.set_parameter!(
396396 return nothing
397397end
398398
399- function _set_parameter_unchecked! (
400- p:: MTKParameters , val, idx:: ParameterIndex ; update_dependent = true )
401- @unpack portion, idx = idx
402- if portion isa SciMLStructures. Tunable
403- p. tunable[idx] = val
404- else
405- i, j, k... = idx
406- if portion isa SciMLStructures. Discrete
407- if isempty (k)
408- p. discrete[i][j] = val
409- else
410- p. discrete[i][j][k... ] = val
411- end
412- elseif portion isa SciMLStructures. Constants
413- if isempty (k)
414- p. constant[i][j] = val
415- else
416- p. constant[i][j][k... ] = val
417- end
418- elseif portion === NONNUMERIC_PORTION
419- if isempty (k)
420- p. nonnumeric[i][j] = val
421- else
422- p. nonnumeric[i][j][k... ] = val
423- end
424- else
425- error (" Unhandled portion $portion " )
426- end
427- end
428- end
429-
430399function narrow_buffer_type_and_fallback_undefs (
431400 oldbuf:: AbstractVector , newbuf:: AbstractVector )
432401 type = Union{}
@@ -448,31 +417,42 @@ function narrow_buffer_type_and_fallback_undefs(
448417 return newerbuf
449418end
450419
451- function validate_parameter_type (ic:: IndexCache , p, index , val)
420+ function validate_parameter_type (ic:: IndexCache , p, idx :: ParameterIndex , val)
452421 p = unwrap (p)
453422 if p isa Symbol
454423 p = get (ic. symbol_to_variable, p, nothing )
455- if p === nothing
456- @warn " No matching variable found for `Symbol` $p , skipping type validation."
457- return nothing
458- end
424+ p === nothing && return validate_parameter_type (ic, idx, val)
425+ end
426+ stype = symtype (p)
427+ sz = if stype <: AbstractArray
428+ Symbolics. shape (p) == Symbolics. Unknown () ? Symbolics. Unknown () : size (p)
429+ elseif stype <: Number
430+ size (p)
431+ else
432+ Symbolics. Unknown ()
459433 end
434+ validate_parameter_type (ic, stype, sz, p, idx, val)
435+ end
436+
437+ function validate_parameter_type (ic:: IndexCache , idx:: ParameterIndex , val)
438+ validate_parameter_type (
439+ ic, get_buffer_template (ic, idx). type, Symbolics. Unknown (), nothing , idx, val)
440+ end
441+
442+ function validate_parameter_type (ic:: IndexCache , stype, sz, sym, index, val)
460443 (; portion) = index
461444 # Nonnumeric parameters have to match the type
462445 if portion === NONNUMERIC_PORTION
463- stype = symtype (p)
464446 val isa stype && return nothing
465- throw (ParameterTypeException (:validate_parameter_type , p , stype, val))
447+ throw (ParameterTypeException (:validate_parameter_type , sym , stype, val))
466448 end
467- stype = symtype (p)
468449 # Array parameters need array values...
469450 if stype <: AbstractArray && ! isa (val, AbstractArray)
470- throw (ParameterTypeException (:validate_parameter_type , p , stype, val))
451+ throw (ParameterTypeException (:validate_parameter_type , sym , stype, val))
471452 end
472453 # ... and must match sizes
473- if stype <: AbstractArray && Symbolics. shape (p) != = Symbolics. Unknown () &&
474- size (val) != size (p)
475- throw (InvalidParameterSizeException (p, val))
454+ if stype <: AbstractArray && sz != Symbolics. Unknown () && size (val) != sz
455+ throw (InvalidParameterSizeException (sym, val))
476456 end
477457 # Early exit
478458 val isa stype && return nothing
@@ -485,15 +465,15 @@ function validate_parameter_type(ic::IndexCache, p, index, val)
485465 # This is for duals and other complicated number types
486466 etype = SciMLBase. parameterless_type (etype)
487467 eltype (val) <: etype || throw (ParameterTypeException (
488- :validate_parameter_type , p , AbstractArray{etype}, val))
468+ :validate_parameter_type , sym , AbstractArray{etype}, val))
489469 else
490470 # Real check
491471 if stype <: Real
492472 stype = Real
493473 end
494474 stype = SciMLBase. parameterless_type (stype)
495475 val isa stype ||
496- throw (ParameterTypeException (:validate_parameter_type , p , stype, val))
476+ throw (ParameterTypeException (:validate_parameter_type , sym , stype, val))
497477 end
498478end
499479
@@ -504,45 +484,69 @@ function indp_to_system(indp)
504484 return indp
505485end
506486
507- function SymbolicIndexingInterface. remake_buffer (indp, oldbuf:: MTKParameters , vals:: Dict )
508- newbuf = @set oldbuf. tunable = Vector {Any} (undef, length ( oldbuf. tunable) )
487+ function SymbolicIndexingInterface. remake_buffer (indp, oldbuf:: MTKParameters , idxs, vals)
488+ newbuf = @set oldbuf. tunable = similar ( oldbuf. tunable, Any )
509489 @set! newbuf. discrete = Tuple (similar (buf, Any) for buf in newbuf. discrete)
510- @set! newbuf. constant = Tuple (Vector {Any} (undef, length (buf))
511- for buf in newbuf. constant)
512- @set! newbuf. nonnumeric = Tuple (Vector {Any} (undef, length (buf))
513- for buf in newbuf. nonnumeric)
514-
515- syms = collect (keys (vals))
516- vals = Dict {Any, Any} (vals)
517- for sym in syms
518- symbolic_type (sym) == ArraySymbolic () || continue
519- is_parameter (indp, sym) && continue
520- stype = symtype (unwrap (sym))
521- stype <: AbstractArray || continue
522- Symbolics. shape (sym) == Symbolics. Unknown () && continue
523- for i in eachindex (sym)
524- vals[sym[i]] = vals[sym][i]
490+ @set! newbuf. constant = Tuple (similar (buf, Any) for buf in newbuf. constant)
491+ @set! newbuf. nonnumeric = Tuple (similar (buf, Any) for buf in newbuf. nonnumeric)
492+
493+ function handle_parameter (ic, sym, idx, val)
494+ if sym === nothing
495+ validate_parameter_type (ic, idx, val)
496+ else
497+ validate_parameter_type (ic, sym, idx, val)
525498 end
499+ # `ParameterIndex(idx)` turns off size validation since it relies on there
500+ # being an existing value
501+ set_parameter! (newbuf, val, ParameterIndex (idx))
526502 end
527503
504+ handled_idxs = Set {ParameterIndex} ()
528505 # If the parameter buffer is an `MTKParameters` object, `indp` must eventually drill
529506 # down to an `AbstractSystem` using `symbolic_container`. We leverage this to get
530507 # the index cache.
531508 ic = get_index_cache (indp_to_system (indp))
532- for (p, val) in vals
533- idx = parameter_index (indp, p)
534- if idx != = nothing
535- validate_parameter_type (ic, p, idx, val)
536- _set_parameter_unchecked! (
537- newbuf, val, idx; update_dependent = false )
538- elseif symbolic_type (p) == ArraySymbolic ()
539- for (i, j) in zip (eachindex (p), eachindex (val))
540- pi = p[i]
541- idx = parameter_index (indp, pi )
542- validate_parameter_type (ic, pi , idx, val[j])
543- _set_parameter_unchecked! (
544- newbuf, val[j], idx; update_dependent = false )
509+ for (idx, val) in zip (idxs, vals)
510+ sym = nothing
511+ if symbolic_type (idx) == ScalarSymbolic ()
512+ sym = idx
513+ idx = parameter_index (ic, sym)
514+ if idx === nothing
515+ @warn " Symbolic variable $sym is not a (non-dependent) parameter in the system"
516+ continue
517+ end
518+ idx in handled_idxs && continue
519+ handle_parameter (ic, sym, idx, val)
520+ push! (handled_idxs, idx)
521+ elseif symbolic_type (idx) == ArraySymbolic ()
522+ sym = idx
523+ idx = parameter_index (ic, sym)
524+ if idx === nothing
525+ Symbolics. shape (sym) == Symbolics. Unknown () &&
526+ throw (ParameterNotInSystem (sym))
527+ size (sym) == size (val) || throw (InvalidParameterSizeException (sym, val))
528+
529+ for (i, vali) in zip (eachindex (sym), eachindex (val))
530+ idx = parameter_index (ic, sym[i])
531+ if idx === nothing
532+ @warn " Symbolic variable $sym is not a (non-dependent) parameter in the system"
533+ continue
534+ end
535+ # Intentionally don't check handled_idxs here because array variables always take priority
536+ # See Issue#2804
537+ handle_parameter (ic, sym[i], idx, val[vali])
538+ push! (handled_idxs, idx)
539+ end
540+ else
541+ idx in handled_idxs && continue
542+ handle_parameter (ic, sym, idx, val)
543+ push! (handled_idxs, idx)
545544 end
545+ else # NotSymbolic
546+ if ! (idx isa ParameterIndex)
547+ throw (ArgumentError (" Expected index for parameter to be a symbolic variable or `ParameterIndex`, got $idx " ))
548+ end
549+ handle_parameter (ic, nothing , idx, val)
546550 end
547551 end
548552
@@ -688,7 +692,7 @@ function jacobian_wrt_vars(pf::F, p::MTKParameters, input_idxs, chunk::C) where
688692
689693 function (p_small_inner)
690694 for (i, val) in zip (input_idxs, p_small_inner)
691- _set_parameter_unchecked ! (p_big, val, i)
695+ set_parameter ! (p_big, val, i)
692696 end
693697 return if pf isa SciMLBase. ParamJacobianWrapper
694698 buffer = Array {dualtype} (undef, size (pf. u))
735739function ParameterTypeException (func, param, expected, val)
736740 TypeError (func, " Parameter $param " , expected, val)
737741end
742+
743+ struct ParameterNotInSystem <: Exception
744+ p:: Any
745+ end
746+
747+ function Base. showerror (io:: IO , e:: ParameterNotInSystem )
748+ println (io, " Symbolic variable $(e. p) is not a parameter in the system" )
749+ end
0 commit comments