Skip to content

Commit 99fc9da

Browse files
refactor: don't splat MTKParameters into observed functions anymore
1 parent aa6acce commit 99fc9da

File tree

3 files changed

+12
-45
lines changed

3 files changed

+12
-45
lines changed

src/systems/abstractsystem.jl

Lines changed: 10 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -681,27 +681,12 @@ function SymbolicIndexingInterface.parameter_observed(sys::AbstractSystem, sym)
681681
rawobs = build_explicit_observed_function(
682682
sys, sym; param_only = true, return_inplace = true)
683683
if rawobs isa Tuple
684-
if is_time_dependent(sys)
685-
obsfn = let oop = rawobs[1], iip = rawobs[2]
686-
f1a(p::MTKParameters, t) = oop(p..., t)
687-
f1a(out, p::MTKParameters, t) = iip(out, p..., t)
688-
end
689-
else
690-
obsfn = let oop = rawobs[1], iip = rawobs[2]
691-
f1b(p::MTKParameters) = oop(p...)
692-
f1b(out, p::MTKParameters) = iip(out, p...)
693-
end
684+
obsfn = let oop = rawobs[1], iip = rawobs[2]
685+
f1a(p, t) = oop(p, t)
686+
f1a(out, p, t) = iip(out, p, t)
694687
end
695688
else
696-
if is_time_dependent(sys)
697-
obsfn = let rawobs = rawobs
698-
f2a(p::MTKParameters, t) = rawobs(p..., t)
699-
end
700-
else
701-
obsfn = let rawobs = rawobs
702-
f2b(p::MTKParameters) = rawobs(p...)
703-
end
704-
end
689+
obsfn = rawobs
705690
end
706691
else
707692
obsfn = build_explicit_observed_function(sys, sym; param_only = true)
@@ -816,21 +801,11 @@ function SymbolicIndexingInterface.observed(
816801
_fn = build_explicit_observed_function(sys, sym; eval_expression, eval_module)
817802

818803
if is_time_dependent(sys)
819-
return let _fn = _fn
820-
fn1(u, p, t) = _fn(u, p, t)
821-
fn1(u, p::MTKParameters, t) = _fn(u, p..., t)
822-
823-
# DDEs
824-
fn1(u, histfn, p, t) = _fn(u, histfn, p, t)
825-
fn1(u, histfn, p::MTKParameters, t) = _fn(u, histfn, p..., t)
826-
fn1
827-
end
804+
return _fn
828805
else
829806
return let _fn = _fn
830807
fn2(u, p) = _fn(u, p)
831-
fn2(u, p::MTKParameters) = _fn(u, p...)
832808
fn2(::Nothing, p) = _fn([], p)
833-
fn2(::Nothing, p::MTKParameters) = _fn([], p...)
834809
fn2
835810
end
836811
end
@@ -2368,8 +2343,8 @@ function linearization_function(sys::AbstractSystem, inputs,
23682343
u_getter = u_getter
23692344

23702345
function (u, p, t)
2371-
p_setter!(oldps, p_getter(u, p..., t))
2372-
newu = u_getter(u, p..., t)
2346+
p_setter!(oldps, p_getter(u, p, t))
2347+
newu = u_getter(u, p, t)
23732348
return newu, oldps
23742349
end
23752350
end
@@ -2380,20 +2355,13 @@ function linearization_function(sys::AbstractSystem, inputs,
23802355

23812356
function (u, p, t)
23822357
state = ProblemState(; u, p, t)
2383-
return u_getter(state), p_getter(state)
2358+
return u_getter(state_values(state), parameter_values(state), current_time(state)), p_getter(state)
23842359
end
23852360
end
23862361
end
23872362
initfn = NonlinearFunction(initsys; eval_expression, eval_module)
23882363
initprobmap = build_explicit_observed_function(
23892364
initsys, unknowns(sys); eval_expression, eval_module)
2390-
if has_index_cache(sys) && get_index_cache(sys) !== nothing
2391-
initprobmap = let inner = initprobmap
2392-
fn(u, p::MTKParameters) = inner(u, p...)
2393-
fn(u, p) = inner(u, p)
2394-
fn
2395-
end
2396-
end
23972365
ps = parameters(sys)
23982366
h = build_explicit_observed_function(sys, outputs; eval_expression, eval_module)
23992367
lin_fun = let diff_idxs = diff_idxs,
@@ -2440,7 +2408,7 @@ function linearization_function(sys::AbstractSystem, inputs,
24402408
fg_xz = ForwardDiff.jacobian(uf, u)
24412409
h_xz = ForwardDiff.jacobian(
24422410
let p = p, t = t
2443-
xz -> p isa MTKParameters ? h(xz, p..., t) : h(xz, p, t)
2411+
xz -> h(xz, p, t)
24442412
end, u)
24452413
pf = SciMLBase.ParamJacobianWrapper(fun, t, u)
24462414
fg_u = jacobian_wrt_vars(pf, p, input_idxs, chunk)
@@ -2452,7 +2420,6 @@ function linearization_function(sys::AbstractSystem, inputs,
24522420
end
24532421
hp = let u = u, t = t
24542422
_hp(p) = h(u, p, t)
2455-
_hp(p::MTKParameters) = h(u, p..., t)
24562423
_hp
24572424
end
24582425
h_u = jacobian_wrt_vars(hp, p, input_idxs, chunk)
@@ -2505,7 +2472,7 @@ function linearize_symbolic(sys::AbstractSystem, inputs,
25052472
dx = fun(sts, p..., t)
25062473

25072474
h = build_explicit_observed_function(sys, outputs; eval_expression, eval_module)
2508-
y = h(sts, p..., t)
2475+
y = h(sts, p, t)
25092476

25102477
fg_xz = Symbolics.jacobian(dx, sts)
25112478
fg_u = Symbolics.jacobian(dx, inputs)

test/clock.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,7 @@ eqs = [yd ~ Sample(dt)(y)
514514

515515
@test sol.prob.kwargs[:disc_saved_values][1].t == sol.t[1:2:end] # Test that the discrete-time system executed at every step of the continuous solver. The solver saves each time step twice, one state value before discrete affect and one after.
516516
@test_nowarn ModelingToolkit.build_explicit_observed_function(
517-
model, model.counter.ud)(sol.u[1], prob.p..., sol.t[1])
517+
model, model.counter.ud)(sol.u[1], prob.p, sol.t[1])
518518

519519
@variables x(t)=1.0 y(t)=1.0
520520
eqs = [D(y) ~ Hold(x)

test/odesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -548,7 +548,7 @@ prob = ODEProblem(
548548
@test_nowarn solve(prob, Tsit5())
549549
obsfn = ModelingToolkit.build_explicit_observed_function(
550550
outersys, bar(3outersys.sys.ms, 3outersys.sys.p))
551-
@test_nowarn obsfn(sol.u[1], prob.p..., sol.t[1])
551+
@test_nowarn obsfn(sol.u[1], prob.p, sol.t[1])
552552

553553
# x/x
554554
@variables x(t)

0 commit comments

Comments
 (0)