Skip to content

Commit de504b5

Browse files
refactor: add _set_parameter_unchecked!, use it for jacobian_wrt_params
1 parent 6398aa3 commit de504b5

File tree

1 file changed

+19
-1
lines changed

1 file changed

+19
-1
lines changed

src/systems/parameter_buffer.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,24 @@ function SymbolicIndexingInterface.set_parameter!(p::MTKParameters, val, idx::Pa
182182
end
183183
end
184184

185+
function _set_parameter_unchecked!(p::MTKParameters, val, idx::ParameterIndex)
186+
@unpack portion, idx = idx
187+
update_dependent = true
188+
if portion isa SciMLStructures.Tunable
189+
p.tunable[idx] = val
190+
elseif portion isa SciMLStructures.Discrete
191+
p.discrete[idx] = val
192+
elseif portion isa SciMLStructures.Constants
193+
p.constant[idx] = val
194+
elseif portion === nothing
195+
p.dependent[idx] = val
196+
update_dependent = false
197+
else
198+
error("Unhandled portion $portion")
199+
end
200+
update_dependent && p.dependent_update_iip !== nothing && p.dependent_update_iip(p.dependent, p...)
201+
end
202+
185203
_subarrays(v::AbstractVector) = isempty(v) ? () : (v,)
186204
_subarrays(v::ArrayPartition) = v.x
187205
_num_subarrays(v::AbstractVector) = 1
@@ -249,7 +267,7 @@ function jacobian_wrt_vars(pf::F, p::MTKParameters, input_idxs, chunk::C) where
249267

250268
function (p_small_inner)
251269
for (i, val) in zip(input_idxs, p_small_inner)
252-
set_parameter!(p_big, val, i)
270+
_set_parameter_unchecked!(p_big, val, i)
253271
end
254272
# tunable, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p_big)
255273
# tunable[input_idxs] .= p_small_inner

0 commit comments

Comments
 (0)