Skip to content

Commit b52ecc2

Browse files
feat: add Initials portion to MTKParameters
1 parent 5567ab2 commit b52ecc2

File tree

1 file changed

+64
-8
lines changed

1 file changed

+64
-8
lines changed

src/systems/parameter_buffer.jl

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@ symconvert(::Type{T}, x) where {T} = convert(T, x)
33
symconvert(::Type{Real}, x::Integer) = convert(Float64, x)
44
symconvert(::Type{V}, x) where {V <: AbstractArray} = convert(V, symconvert.(eltype(V), x))
55

6-
struct MTKParameters{T, D, C, N, H}
6+
struct MTKParameters{T, I, D, C, N, H}
77
tunable::T
8+
initials::I
89
discrete::D
910
constant::C
1011
nonnumeric::N
@@ -65,6 +66,8 @@ function MTKParameters(
6566

6667
tunable_buffer = Vector{ic.tunable_buffer_size.type}(
6768
undef, ic.tunable_buffer_size.length)
69+
initials_buffer = Vector{ic.initials_buffer_size.type}(
70+
undef, ic.initials_buffer_size.length)
6871
disc_buffer = Tuple(BlockedArray(
6972
Vector{subbuffer_sizes[1].type}(
7073
undef, sum(x -> x.length, subbuffer_sizes)),
@@ -79,6 +82,9 @@ function MTKParameters(
7982
if haskey(ic.tunable_idx, sym)
8083
idx = ic.tunable_idx[sym]
8184
tunable_buffer[idx] = val
85+
elseif haskey(ic.initials_idx, sym)
86+
idx = ic.initials_idx[sym]
87+
initials_buffer[idx] = val
8288
elseif haskey(ic.discrete_idx, sym)
8389
idx = ic.discrete_idx[sym]
8490
disc_buffer[idx.buffer_idx][idx.idx_in_buffer] = val
@@ -124,15 +130,19 @@ function MTKParameters(
124130
if isempty(tunable_buffer)
125131
tunable_buffer = SizedVector{0, Float64}()
126132
end
133+
initials_buffer = narrow_buffer_type(initials_buffer)
134+
if isempty(initials_buffer)
135+
initials_buffer = SizedVector{0, Float64}()
136+
end
127137
disc_buffer = narrow_buffer_type.(disc_buffer)
128138
const_buffer = narrow_buffer_type.(const_buffer)
129139
# Don't narrow nonnumeric types
130140
nonnumeric_buffer = nonnumeric_buffer
131141

132142
mtkps = MTKParameters{
133-
typeof(tunable_buffer), typeof(disc_buffer), typeof(const_buffer),
134-
typeof(nonnumeric_buffer), typeof(())}(tunable_buffer,
135-
disc_buffer, const_buffer, nonnumeric_buffer, ())
143+
typeof(tunable_buffer), typeof(initials_buffer), typeof(disc_buffer),
144+
typeof(const_buffer), typeof(nonnumeric_buffer), typeof(())}(tunable_buffer,
145+
initials_buffer, disc_buffer, const_buffer, nonnumeric_buffer, ())
136146
return mtkps
137147
end
138148

@@ -252,6 +262,26 @@ function SciMLStructures.replace!(::SciMLStructures.Tunable, p::MTKParameters, n
252262
return nothing
253263
end
254264

265+
function SciMLStructures.canonicalize(::SciMLStructures.Initials, p::MTKParameters)
266+
arr = p.initials
267+
repack = let p = p
268+
function (new_val)
269+
return SciMLStructures.replace(SciMLStructures.Initials(), p, new_val)
270+
end
271+
end
272+
return arr, repack, true
273+
end
274+
275+
function SciMLStructures.replace(::SciMLStructures.Initials, p::MTKParameters, newvals)
276+
@set! p.initials = newvals
277+
return p
278+
end
279+
280+
function SciMLStructures.replace!(::SciMLStructures.Initials, p::MTKParameters, newvals)
281+
copyto!(p.initials, newvals)
282+
return nothing
283+
end
284+
255285
for (Portion, field, recurse) in [(SciMLStructures.Discrete, :discrete, 1)
256286
(SciMLStructures.Constants, :constant, 1)
257287
(Nonnumeric, :nonnumeric, 1)
@@ -279,12 +309,14 @@ end
279309

280310
function Base.copy(p::MTKParameters)
281311
tunable = copy(p.tunable)
312+
initials = copy(p.initials)
282313
discrete = Tuple(eltype(buf) <: Real ? copy(buf) : copy.(buf) for buf in p.discrete)
283314
constant = Tuple(eltype(buf) <: Real ? copy(buf) : copy.(buf) for buf in p.constant)
284315
nonnumeric = copy.(p.nonnumeric)
285316
caches = copy.(p.caches)
286317
return MTKParameters(
287318
tunable,
319+
initials,
288320
discrete,
289321
constant,
290322
nonnumeric,
@@ -300,6 +332,9 @@ function _ducktyped_parameter_values(p, pind::ParameterIndex)
300332
if portion isa SciMLStructures.Tunable
301333
return idx isa Int ? p.tunable[idx] : view(p.tunable, idx)
302334
end
335+
if portion isa SciMLStructures.Initials
336+
return idx isa Int ? p.initials[idx] : view(p.initials, idx)
337+
end
303338
i, j, k... = idx
304339
if portion isa SciMLStructures.Discrete
305340
return isempty(k) ? p.discrete[i][j] : p.discrete[i][j][k...]
@@ -320,6 +355,11 @@ function SymbolicIndexingInterface.set_parameter!(
320355
throw(InvalidParameterSizeException(size(idx), size(val)))
321356
end
322357
p.tunable[idx] = val
358+
elseif portion isa SciMLStructures.Initials
359+
if validate_size && size(val) !== size(idx)
360+
throw(InvalidParameterSizeException(size(idx), size(val)))
361+
end
362+
p.initials[idx] = val
323363
else
324364
i, j, k... = idx
325365
if portion isa SciMLStructures.Discrete
@@ -394,7 +434,8 @@ end
394434

395435
function validate_parameter_type(ic::IndexCache, idx::ParameterIndex, val)
396436
stype = get_buffer_template(ic, idx).type
397-
if idx.portion == SciMLStructures.Tunable() && !(idx.idx isa Int)
437+
if (idx.portion == SciMLStructures.Tunable() ||
438+
idx.portion == SciMLStructures.Initials()) && !(idx.idx isa Int)
398439
stype = AbstractArray{<:stype}
399440
end
400441
validate_parameter_type(
@@ -454,6 +495,7 @@ function SymbolicIndexingInterface.remake_buffer(indp, oldbuf::MTKParameters, id
454495
end
455496
function _remake_buffer(indp, oldbuf::MTKParameters, idxs, vals; validate = true)
456497
newbuf = @set oldbuf.tunable = similar(oldbuf.tunable, Any)
498+
@set! newbuf.initials = similar(oldbuf.initials, Any)
457499
@set! newbuf.discrete = Tuple(similar(buf, Any) for buf in newbuf.discrete)
458500
@set! newbuf.constant = Tuple(similar(buf, Any) for buf in newbuf.constant)
459501
@set! newbuf.nonnumeric = Tuple(similar(buf, Any) for buf in newbuf.nonnumeric)
@@ -529,6 +571,12 @@ function _remake_buffer(indp, oldbuf::MTKParameters, idxs, vals; validate = true
529571
T = promote_type(eltype(newbuf.tunable), Float64)
530572
@set! newbuf.tunable = T.(newbuf.tunable)
531573
end
574+
@set! newbuf.initials = narrow_buffer_type_and_fallback_undefs(
575+
oldbuf.initials, newbuf.initials)
576+
if eltype(newbuf.initials) <: Integer
577+
T = promote_type(eltype(newbuf.initials), Float64)
578+
@set! newbuf.initials = T.(newbuf.initials)
579+
end
532580
@set! newbuf.discrete = narrow_buffer_type_and_fallback_undefs.(
533581
oldbuf.discrete, newbuf.discrete)
534582
@set! newbuf.constant = narrow_buffer_type_and_fallback_undefs.(
@@ -540,6 +588,7 @@ end
540588

541589
function as_any_buffer(p::MTKParameters)
542590
@set! p.tunable = similar(p.tunable, Any)
591+
@set! p.initials = similar(p.initials, Any)
543592
@set! p.discrete = Tuple(similar(buf, Any) for buf in p.discrete)
544593
@set! p.constant = Tuple(similar(buf, Any) for buf in p.constant)
545594
@set! p.nonnumeric = Tuple(similar(buf, Any) for buf in p.nonnumeric)
@@ -615,11 +664,14 @@ end
615664
# getindex indexes the vectors, setindex! linearly indexes values
616665
# it's inconsistent, but we need it to be this way
617666
@generated function Base.getindex(
618-
ps::MTKParameters{T, D, C, N, H}, idx::Int) where {T, D, C, N, H}
667+
ps::MTKParameters{T, I, D, C, N, H}, idx::Int) where {T, I, D, C, N, H}
619668
paths = []
620669
if !(T <: SizedVector{0, Float64})
621670
push!(paths, :(ps.tunable))
622671
end
672+
if !(I <: SizedVector{0, Float64})
673+
push!(paths, :(ps.initials))
674+
end
623675
for i in 1:fieldcount(D)
624676
push!(paths, :(ps.discrete[$i]))
625677
end
@@ -641,11 +693,15 @@ end
641693
return Expr(:block, expr, :(throw(BoundsError(ps, idx))))
642694
end
643695

644-
@generated function Base.length(ps::MTKParameters{T, D, C, N, H}) where {T, D, C, N, H}
696+
@generated function Base.length(ps::MTKParameters{
697+
T, I, D, C, N, H}) where {T, I, D, C, N, H}
645698
len = 0
646699
if !(T <: SizedVector{0, Float64})
647700
len += 1
648701
end
702+
if !(I <: SizedVector{0, Float64})
703+
len += 1
704+
end
649705
len += fieldcount(D) + fieldcount(C) + fieldcount(N) + fieldcount(H)
650706
return len
651707
end
@@ -668,7 +724,7 @@ function Base.iterate(buf::MTKParameters, state = 1)
668724
end
669725

670726
function Base.:(==)(a::MTKParameters, b::MTKParameters)
671-
return a.tunable == b.tunable && a.discrete == b.discrete &&
727+
return a.tunable == b.tunable && a.initials == b.initials && a.discrete == b.discrete &&
672728
a.constant == b.constant && a.nonnumeric == b.nonnumeric &&
673729
all(Iterators.map(a.caches, b.caches) do acache, bcache
674730
eltype(acache) == eltype(bcache) && length(acache) == length(bcache)

0 commit comments

Comments
 (0)