@@ -3,8 +3,9 @@ symconvert(::Type{T}, x) where {T} = convert(T, x)
33symconvert (:: Type{Real} , x:: Integer ) = convert (Float64, x)
44symconvert (:: Type{V} , x) where {V <: AbstractArray } = convert (V, symconvert .(eltype (V), x))
55
6- struct MTKParameters{T, D, C, N, H}
6+ struct MTKParameters{T, I, D, C, N, H}
77 tunable:: T
8+ initials:: I
89 discrete:: D
910 constant:: C
1011 nonnumeric:: N
@@ -65,6 +66,8 @@ function MTKParameters(
6566
6667 tunable_buffer = Vector {ic.tunable_buffer_size.type} (
6768 undef, ic. tunable_buffer_size. length)
69+ initials_buffer = Vector {ic.initials_buffer_size.type} (
70+ undef, ic. initials_buffer_size. length)
6871 disc_buffer = Tuple (BlockedArray (
6972 Vector {subbuffer_sizes[1].type} (
7073 undef, sum (x -> x. length, subbuffer_sizes)),
@@ -79,6 +82,9 @@ function MTKParameters(
7982 if haskey (ic. tunable_idx, sym)
8083 idx = ic. tunable_idx[sym]
8184 tunable_buffer[idx] = val
85+ elseif haskey (ic. initials_idx, sym)
86+ idx = ic. initials_idx[sym]
87+ initials_buffer[idx] = val
8288 elseif haskey (ic. discrete_idx, sym)
8389 idx = ic. discrete_idx[sym]
8490 disc_buffer[idx. buffer_idx][idx. idx_in_buffer] = val
@@ -124,15 +130,19 @@ function MTKParameters(
124130 if isempty (tunable_buffer)
125131 tunable_buffer = SizedVector {0, Float64} ()
126132 end
133+ initials_buffer = narrow_buffer_type (initials_buffer)
134+ if isempty (initials_buffer)
135+ initials_buffer = SizedVector {0, Float64} ()
136+ end
127137 disc_buffer = narrow_buffer_type .(disc_buffer)
128138 const_buffer = narrow_buffer_type .(const_buffer)
129139 # Don't narrow nonnumeric types
130140 nonnumeric_buffer = nonnumeric_buffer
131141
132142 mtkps = MTKParameters{
133- typeof (tunable_buffer), typeof (disc_buffer ), typeof (const_buffer ),
134- typeof (nonnumeric_buffer), typeof (())}(tunable_buffer,
135- disc_buffer, const_buffer, nonnumeric_buffer, ())
143+ typeof (tunable_buffer), typeof (initials_buffer ), typeof (disc_buffer ),
144+ typeof (const_buffer), typeof ( nonnumeric_buffer), typeof (())}(tunable_buffer,
145+ initials_buffer, disc_buffer, const_buffer, nonnumeric_buffer, ())
136146 return mtkps
137147end
138148
@@ -252,6 +262,26 @@ function SciMLStructures.replace!(::SciMLStructures.Tunable, p::MTKParameters, n
252262 return nothing
253263end
254264
265+ function SciMLStructures. canonicalize (:: SciMLStructures.Initials , p:: MTKParameters )
266+ arr = p. initials
267+ repack = let p = p
268+ function (new_val)
269+ return SciMLStructures. replace (SciMLStructures. Initials (), p, new_val)
270+ end
271+ end
272+ return arr, repack, true
273+ end
274+
275+ function SciMLStructures. replace (:: SciMLStructures.Initials , p:: MTKParameters , newvals)
276+ @set! p. initials = newvals
277+ return p
278+ end
279+
280+ function SciMLStructures. replace! (:: SciMLStructures.Initials , p:: MTKParameters , newvals)
281+ copyto! (p. initials, newvals)
282+ return nothing
283+ end
284+
255285for (Portion, field, recurse) in [(SciMLStructures. Discrete, :discrete , 1 )
256286 (SciMLStructures. Constants, :constant , 1 )
257287 (Nonnumeric, :nonnumeric , 1 )
@@ -279,12 +309,14 @@ end
279309
280310function Base. copy (p:: MTKParameters )
281311 tunable = copy (p. tunable)
312+ initials = copy (p. initials)
282313 discrete = Tuple (eltype (buf) <: Real ? copy (buf) : copy .(buf) for buf in p. discrete)
283314 constant = Tuple (eltype (buf) <: Real ? copy (buf) : copy .(buf) for buf in p. constant)
284315 nonnumeric = copy .(p. nonnumeric)
285316 caches = copy .(p. caches)
286317 return MTKParameters (
287318 tunable,
319+ initials,
288320 discrete,
289321 constant,
290322 nonnumeric,
@@ -300,6 +332,9 @@ function _ducktyped_parameter_values(p, pind::ParameterIndex)
300332 if portion isa SciMLStructures. Tunable
301333 return idx isa Int ? p. tunable[idx] : view (p. tunable, idx)
302334 end
335+ if portion isa SciMLStructures. Initials
336+ return idx isa Int ? p. initials[idx] : view (p. initials, idx)
337+ end
303338 i, j, k... = idx
304339 if portion isa SciMLStructures. Discrete
305340 return isempty (k) ? p. discrete[i][j] : p. discrete[i][j][k... ]
@@ -320,6 +355,11 @@ function SymbolicIndexingInterface.set_parameter!(
320355 throw (InvalidParameterSizeException (size (idx), size (val)))
321356 end
322357 p. tunable[idx] = val
358+ elseif portion isa SciMLStructures. Initials
359+ if validate_size && size (val) != = size (idx)
360+ throw (InvalidParameterSizeException (size (idx), size (val)))
361+ end
362+ p. initials[idx] = val
323363 else
324364 i, j, k... = idx
325365 if portion isa SciMLStructures. Discrete
394434
395435function validate_parameter_type (ic:: IndexCache , idx:: ParameterIndex , val)
396436 stype = get_buffer_template (ic, idx). type
397- if idx. portion == SciMLStructures. Tunable () && ! (idx. idx isa Int)
437+ if (idx. portion == SciMLStructures. Tunable () ||
438+ idx. portion == SciMLStructures. Initials ()) && ! (idx. idx isa Int)
398439 stype = AbstractArray{<: stype }
399440 end
400441 validate_parameter_type (
@@ -454,6 +495,7 @@ function SymbolicIndexingInterface.remake_buffer(indp, oldbuf::MTKParameters, id
454495end
455496function _remake_buffer (indp, oldbuf:: MTKParameters , idxs, vals; validate = true )
456497 newbuf = @set oldbuf. tunable = similar (oldbuf. tunable, Any)
498+ @set! newbuf. initials = similar (oldbuf. initials, Any)
457499 @set! newbuf. discrete = Tuple (similar (buf, Any) for buf in newbuf. discrete)
458500 @set! newbuf. constant = Tuple (similar (buf, Any) for buf in newbuf. constant)
459501 @set! newbuf. nonnumeric = Tuple (similar (buf, Any) for buf in newbuf. nonnumeric)
@@ -529,6 +571,12 @@ function _remake_buffer(indp, oldbuf::MTKParameters, idxs, vals; validate = true
529571 T = promote_type (eltype (newbuf. tunable), Float64)
530572 @set! newbuf. tunable = T .(newbuf. tunable)
531573 end
574+ @set! newbuf. initials = narrow_buffer_type_and_fallback_undefs (
575+ oldbuf. initials, newbuf. initials)
576+ if eltype (newbuf. initials) <: Integer
577+ T = promote_type (eltype (newbuf. initials), Float64)
578+ @set! newbuf. initials = T .(newbuf. initials)
579+ end
532580 @set! newbuf. discrete = narrow_buffer_type_and_fallback_undefs .(
533581 oldbuf. discrete, newbuf. discrete)
534582 @set! newbuf. constant = narrow_buffer_type_and_fallback_undefs .(
540588
541589function as_any_buffer (p:: MTKParameters )
542590 @set! p. tunable = similar (p. tunable, Any)
591+ @set! p. initials = similar (p. initials, Any)
543592 @set! p. discrete = Tuple (similar (buf, Any) for buf in p. discrete)
544593 @set! p. constant = Tuple (similar (buf, Any) for buf in p. constant)
545594 @set! p. nonnumeric = Tuple (similar (buf, Any) for buf in p. nonnumeric)
@@ -615,11 +664,14 @@ end
615664# getindex indexes the vectors, setindex! linearly indexes values
616665# it's inconsistent, but we need it to be this way
617666@generated function Base. getindex (
618- ps:: MTKParameters{T, D, C, N, H} , idx:: Int ) where {T, D, C, N, H}
667+ ps:: MTKParameters{T, I, D, C, N, H} , idx:: Int ) where {T, I , D, C, N, H}
619668 paths = []
620669 if ! (T <: SizedVector{0, Float64} )
621670 push! (paths, :(ps. tunable))
622671 end
672+ if ! (I <: SizedVector{0, Float64} )
673+ push! (paths, :(ps. initials))
674+ end
623675 for i in 1 : fieldcount (D)
624676 push! (paths, :(ps. discrete[$ i]))
625677 end
@@ -641,11 +693,15 @@ end
641693 return Expr (:block , expr, :(throw (BoundsError (ps, idx))))
642694end
643695
644- @generated function Base. length (ps:: MTKParameters{T, D, C, N, H} ) where {T, D, C, N, H}
696+ @generated function Base. length (ps:: MTKParameters {
697+ T, I, D, C, N, H}) where {T, I, D, C, N, H}
645698 len = 0
646699 if ! (T <: SizedVector{0, Float64} )
647700 len += 1
648701 end
702+ if ! (I <: SizedVector{0, Float64} )
703+ len += 1
704+ end
649705 len += fieldcount (D) + fieldcount (C) + fieldcount (N) + fieldcount (H)
650706 return len
651707end
@@ -668,7 +724,7 @@ function Base.iterate(buf::MTKParameters, state = 1)
668724end
669725
670726function Base.:(== )(a:: MTKParameters , b:: MTKParameters )
671- return a. tunable == b. tunable && a. discrete == b. discrete &&
727+ return a. tunable == b. tunable && a. initials == b . initials && a . discrete == b. discrete &&
672728 a. constant == b. constant && a. nonnumeric == b. nonnumeric &&
673729 all (Iterators. map (a. caches, b. caches) do acache, bcache
674730 eltype (acache) == eltype (bcache) && length (acache) == length (bcache)
0 commit comments