Skip to content

Commit 1218152

Browse files
feat: add copy method for MTKParameters
1 parent 3e0aea0 commit 1218152

File tree

1 file changed

+30
-7
lines changed

1 file changed

+30
-7
lines changed

src/systems/parameter_buffer.jl

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ function MTKParameters(sys::AbstractSystem, p; tofloat = false, use_union = fals
3939
for (sym, _) in p
4040
if istree(sym) && operation(sym) === getindex &&
4141
is_parameter(sys, arguments(sym)[begin])
42-
# error("Scalarized parameter values are not supported. Instead of `[p[1] => 1.0, p[2] => 2.0]` use `[p => [1.0, 2.0]]`")
42+
error("Scalarized parameter values are not supported. Instead of `[p[1] => 1.0, p[2] => 2.0]` use `[p => [1.0, 2.0]]`")
4343
end
4444
end
4545

@@ -121,7 +121,7 @@ function MTKParameters(sys::AbstractSystem, p; tofloat = false, use_union = fals
121121
end
122122

123123
function buffer_to_arraypartition(buf)
124-
return ArrayPartition((eltype(v) isa AbstractArray ? buffer_to_arraypartition(v) : v for v in buf)...)
124+
return ArrayPartition(Tuple(eltype(v) <: AbstractArray ? buffer_to_arraypartition(v) : v for v in buf))
125125
end
126126

127127
function split_into_buffers(raw::AbstractArray, buf; recurse = true)
@@ -146,13 +146,19 @@ for (Portion, field) in [(SciMLStructures.Tunable, :tunable)
146146
(SciMLStructures.Discrete, :discrete)
147147
(SciMLStructures.Constants, :constant)]
148148
@eval function SciMLStructures.canonicalize(::$Portion, p::MTKParameters)
149-
function repack(_) # aliases, so we don't need to use the parameter
150-
if p.dependent_update_iip !== nothing
151-
p.dependent_update_iip(ArrayPartition(p.dependent), p...)
149+
as_vector = buffer_to_arraypartition(p.$field)
150+
repack = let as_vector = as_vector, p = p
151+
function (new_val)
152+
if new_val !== as_vector
153+
p.$field = split_into_buffers(new_val, p.$field)
154+
end
155+
if p.dependent_update_iip !== nothing
156+
p.dependent_update_iip(ArrayPartition(p.dependent), p...)
157+
end
158+
p
152159
end
153-
p
154160
end
155-
return buffer_to_arraypartition(p.$field), repack, true
161+
return as_vector, repack, true
156162
end
157163

158164
@eval function SciMLStructures.replace(::$Portion, p::MTKParameters, newvals)
@@ -176,6 +182,23 @@ for (Portion, field) in [(SciMLStructures.Tunable, :tunable)
176182
end
177183
end
178184

185+
function Base.copy(p::MTKParameters)
186+
tunable = Tuple(eltype(buf) <: Real ? copy(buf) : copy.(buf) for buf in p.tunable)
187+
discrete = Tuple(eltype(buf) <: Real ? copy(buf) : copy.(buf) for buf in p.discrete)
188+
constant = Tuple(eltype(buf) <: Real ? copy(buf) : copy.(buf) for buf in p.constant)
189+
dependent = Tuple(eltype(buf) <: Real ? copy(buf) : copy.(buf) for buf in p.dependent)
190+
nonnumeric = copy.(p.nonnumeric)
191+
return MTKParameters(
192+
tunable,
193+
discrete,
194+
constant,
195+
dependent,
196+
nonnumeric,
197+
p.dependent_update_iip,
198+
p.dependent_update_oop,
199+
)
200+
end
201+
179202
function SymbolicIndexingInterface.parameter_values(p::MTKParameters, pind::ParameterIndex)
180203
@unpack portion, idx = pind
181204
i, j, k... = idx

0 commit comments

Comments
 (0)