@@ -396,37 +396,6 @@ function SymbolicIndexingInterface.set_parameter!(
396
396
return nothing
397
397
end
398
398
399
- function _set_parameter_unchecked! (
400
- p:: MTKParameters , val, idx:: ParameterIndex ; update_dependent = true )
401
- @unpack portion, idx = idx
402
- if portion isa SciMLStructures. Tunable
403
- p. tunable[idx] = val
404
- else
405
- i, j, k... = idx
406
- if portion isa SciMLStructures. Discrete
407
- if isempty (k)
408
- p. discrete[i][j] = val
409
- else
410
- p. discrete[i][j][k... ] = val
411
- end
412
- elseif portion isa SciMLStructures. Constants
413
- if isempty (k)
414
- p. constant[i][j] = val
415
- else
416
- p. constant[i][j][k... ] = val
417
- end
418
- elseif portion === NONNUMERIC_PORTION
419
- if isempty (k)
420
- p. nonnumeric[i][j] = val
421
- else
422
- p. nonnumeric[i][j][k... ] = val
423
- end
424
- else
425
- error (" Unhandled portion $portion " )
426
- end
427
- end
428
- end
429
-
430
399
function narrow_buffer_type_and_fallback_undefs (
431
400
oldbuf:: AbstractVector , newbuf:: AbstractVector )
432
401
type = Union{}
@@ -448,31 +417,42 @@ function narrow_buffer_type_and_fallback_undefs(
448
417
return newerbuf
449
418
end
450
419
451
- function validate_parameter_type (ic:: IndexCache , p, index , val)
420
+ function validate_parameter_type (ic:: IndexCache , p, idx :: ParameterIndex , val)
452
421
p = unwrap (p)
453
422
if p isa Symbol
454
423
p = get (ic. symbol_to_variable, p, nothing )
455
- if p === nothing
456
- @warn " No matching variable found for `Symbol` $p , skipping type validation."
457
- return nothing
458
- end
424
+ p === nothing && return validate_parameter_type (ic, idx, val)
425
+ end
426
+ stype = symtype (p)
427
+ sz = if stype <: AbstractArray
428
+ Symbolics. shape (p) == Symbolics. Unknown () ? Symbolics. Unknown () : size (p)
429
+ elseif stype <: Number
430
+ size (p)
431
+ else
432
+ Symbolics. Unknown ()
459
433
end
434
+ validate_parameter_type (ic, stype, sz, p, idx, val)
435
+ end
436
+
437
+ function validate_parameter_type (ic:: IndexCache , idx:: ParameterIndex , val)
438
+ validate_parameter_type (
439
+ ic, get_buffer_template (ic, idx). type, Symbolics. Unknown (), nothing , idx, val)
440
+ end
441
+
442
+ function validate_parameter_type (ic:: IndexCache , stype, sz, sym, index, val)
460
443
(; portion) = index
461
444
# Nonnumeric parameters have to match the type
462
445
if portion === NONNUMERIC_PORTION
463
- stype = symtype (p)
464
446
val isa stype && return nothing
465
- throw (ParameterTypeException (:validate_parameter_type , p , stype, val))
447
+ throw (ParameterTypeException (:validate_parameter_type , sym , stype, val))
466
448
end
467
- stype = symtype (p)
468
449
# Array parameters need array values...
469
450
if stype <: AbstractArray && ! isa (val, AbstractArray)
470
- throw (ParameterTypeException (:validate_parameter_type , p , stype, val))
451
+ throw (ParameterTypeException (:validate_parameter_type , sym , stype, val))
471
452
end
472
453
# ... and must match sizes
473
- if stype <: AbstractArray && Symbolics. shape (p) != = Symbolics. Unknown () &&
474
- size (val) != size (p)
475
- throw (InvalidParameterSizeException (p, val))
454
+ if stype <: AbstractArray && sz != Symbolics. Unknown () && size (val) != sz
455
+ throw (InvalidParameterSizeException (sym, val))
476
456
end
477
457
# Early exit
478
458
val isa stype && return nothing
@@ -485,15 +465,15 @@ function validate_parameter_type(ic::IndexCache, p, index, val)
485
465
# This is for duals and other complicated number types
486
466
etype = SciMLBase. parameterless_type (etype)
487
467
eltype (val) <: etype || throw (ParameterTypeException (
488
- :validate_parameter_type , p , AbstractArray{etype}, val))
468
+ :validate_parameter_type , sym , AbstractArray{etype}, val))
489
469
else
490
470
# Real check
491
471
if stype <: Real
492
472
stype = Real
493
473
end
494
474
stype = SciMLBase. parameterless_type (stype)
495
475
val isa stype ||
496
- throw (ParameterTypeException (:validate_parameter_type , p , stype, val))
476
+ throw (ParameterTypeException (:validate_parameter_type , sym , stype, val))
497
477
end
498
478
end
499
479
@@ -504,45 +484,69 @@ function indp_to_system(indp)
504
484
return indp
505
485
end
506
486
507
- function SymbolicIndexingInterface. remake_buffer (indp, oldbuf:: MTKParameters , vals:: Dict )
508
- newbuf = @set oldbuf. tunable = Vector {Any} (undef, length ( oldbuf. tunable) )
487
+ function SymbolicIndexingInterface. remake_buffer (indp, oldbuf:: MTKParameters , idxs, vals)
488
+ newbuf = @set oldbuf. tunable = similar ( oldbuf. tunable, Any )
509
489
@set! newbuf. discrete = Tuple (similar (buf, Any) for buf in newbuf. discrete)
510
- @set! newbuf. constant = Tuple (Vector {Any} (undef, length (buf))
511
- for buf in newbuf. constant)
512
- @set! newbuf. nonnumeric = Tuple (Vector {Any} (undef, length (buf))
513
- for buf in newbuf. nonnumeric)
514
-
515
- syms = collect (keys (vals))
516
- vals = Dict {Any, Any} (vals)
517
- for sym in syms
518
- symbolic_type (sym) == ArraySymbolic () || continue
519
- is_parameter (indp, sym) && continue
520
- stype = symtype (unwrap (sym))
521
- stype <: AbstractArray || continue
522
- Symbolics. shape (sym) == Symbolics. Unknown () && continue
523
- for i in eachindex (sym)
524
- vals[sym[i]] = vals[sym][i]
490
+ @set! newbuf. constant = Tuple (similar (buf, Any) for buf in newbuf. constant)
491
+ @set! newbuf. nonnumeric = Tuple (similar (buf, Any) for buf in newbuf. nonnumeric)
492
+
493
+ function handle_parameter (ic, sym, idx, val)
494
+ if sym === nothing
495
+ validate_parameter_type (ic, idx, val)
496
+ else
497
+ validate_parameter_type (ic, sym, idx, val)
525
498
end
499
+ # `ParameterIndex(idx)` turns off size validation since it relies on there
500
+ # being an existing value
501
+ set_parameter! (newbuf, val, ParameterIndex (idx))
526
502
end
527
503
504
+ handled_idxs = Set {ParameterIndex} ()
528
505
# If the parameter buffer is an `MTKParameters` object, `indp` must eventually drill
529
506
# down to an `AbstractSystem` using `symbolic_container`. We leverage this to get
530
507
# the index cache.
531
508
ic = get_index_cache (indp_to_system (indp))
532
- for (p, val) in vals
533
- idx = parameter_index (indp, p)
534
- if idx != = nothing
535
- validate_parameter_type (ic, p, idx, val)
536
- _set_parameter_unchecked! (
537
- newbuf, val, idx; update_dependent = false )
538
- elseif symbolic_type (p) == ArraySymbolic ()
539
- for (i, j) in zip (eachindex (p), eachindex (val))
540
- pi = p[i]
541
- idx = parameter_index (indp, pi )
542
- validate_parameter_type (ic, pi , idx, val[j])
543
- _set_parameter_unchecked! (
544
- newbuf, val[j], idx; update_dependent = false )
509
+ for (idx, val) in zip (idxs, vals)
510
+ sym = nothing
511
+ if symbolic_type (idx) == ScalarSymbolic ()
512
+ sym = idx
513
+ idx = parameter_index (ic, sym)
514
+ if idx === nothing
515
+ @warn " Symbolic variable $sym is not a (non-dependent) parameter in the system"
516
+ continue
517
+ end
518
+ idx in handled_idxs && continue
519
+ handle_parameter (ic, sym, idx, val)
520
+ push! (handled_idxs, idx)
521
+ elseif symbolic_type (idx) == ArraySymbolic ()
522
+ sym = idx
523
+ idx = parameter_index (ic, sym)
524
+ if idx === nothing
525
+ Symbolics. shape (sym) == Symbolics. Unknown () &&
526
+ throw (ParameterNotInSystem (sym))
527
+ size (sym) == size (val) || throw (InvalidParameterSizeException (sym, val))
528
+
529
+ for (i, vali) in zip (eachindex (sym), eachindex (val))
530
+ idx = parameter_index (ic, sym[i])
531
+ if idx === nothing
532
+ @warn " Symbolic variable $sym is not a (non-dependent) parameter in the system"
533
+ continue
534
+ end
535
+ # Intentionally don't check handled_idxs here because array variables always take priority
536
+ # See Issue#2804
537
+ handle_parameter (ic, sym[i], idx, val[vali])
538
+ push! (handled_idxs, idx)
539
+ end
540
+ else
541
+ idx in handled_idxs && continue
542
+ handle_parameter (ic, sym, idx, val)
543
+ push! (handled_idxs, idx)
545
544
end
545
+ else # NotSymbolic
546
+ if ! (idx isa ParameterIndex)
547
+ throw (ArgumentError (" Expected index for parameter to be a symbolic variable or `ParameterIndex`, got $idx " ))
548
+ end
549
+ handle_parameter (ic, nothing , idx, val)
546
550
end
547
551
end
548
552
@@ -688,7 +692,7 @@ function jacobian_wrt_vars(pf::F, p::MTKParameters, input_idxs, chunk::C) where
688
692
689
693
function (p_small_inner)
690
694
for (i, val) in zip (input_idxs, p_small_inner)
691
- _set_parameter_unchecked ! (p_big, val, i)
695
+ set_parameter ! (p_big, val, i)
692
696
end
693
697
return if pf isa SciMLBase. ParamJacobianWrapper
694
698
buffer = Array {dualtype} (undef, size (pf. u))
735
739
function ParameterTypeException (func, param, expected, val)
736
740
TypeError (func, " Parameter $param " , expected, val)
737
741
end
742
+
743
+ struct ParameterNotInSystem <: Exception
744
+ p:: Any
745
+ end
746
+
747
+ function Base. showerror (io:: IO , e:: ParameterNotInSystem )
748
+ println (io, " Symbolic variable $(e. p) is not a parameter in the system" )
749
+ end
0 commit comments