Skip to content

Commit d5912be

Browse files
refactor: don't splat MTKParameters into observed functions anymore
1 parent ab639eb commit d5912be

File tree

6 files changed

+19
-43
lines changed

6 files changed

+19
-43
lines changed

src/systems/abstractsystem.jl

Lines changed: 13 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -695,25 +695,17 @@ function SymbolicIndexingInterface.parameter_observed(sys::AbstractSystem, sym)
695695
if rawobs isa Tuple
696696
if is_time_dependent(sys)
697697
obsfn = let oop = rawobs[1], iip = rawobs[2]
698-
f1a(p::MTKParameters, t) = oop(p..., t)
699-
f1a(out, p::MTKParameters, t) = iip(out, p..., t)
698+
f1a(p, t) = oop(p, t)
699+
f1a(out, p, t) = iip(out, p, t)
700700
end
701701
else
702702
obsfn = let oop = rawobs[1], iip = rawobs[2]
703-
f1b(p::MTKParameters) = oop(p...)
704-
f1b(out, p::MTKParameters) = iip(out, p...)
703+
f1b(p) = oop(p)
704+
f1b(out, p) = iip(out, p)
705705
end
706706
end
707707
else
708-
if is_time_dependent(sys)
709-
obsfn = let rawobs = rawobs
710-
f2a(p::MTKParameters, t) = rawobs(p..., t)
711-
end
712-
else
713-
obsfn = let rawobs = rawobs
714-
f2b(p::MTKParameters) = rawobs(p...)
715-
end
716-
end
708+
obsfn = rawobs
717709
end
718710
else
719711
obsfn = build_explicit_observed_function(sys, sym; param_only = true)
@@ -828,21 +820,11 @@ function SymbolicIndexingInterface.observed(
828820
_fn = build_explicit_observed_function(sys, sym; eval_expression, eval_module)
829821

830822
if is_time_dependent(sys)
831-
return let _fn = _fn
832-
fn1(u, p, t) = _fn(u, p, t)
833-
fn1(u, p::MTKParameters, t) = _fn(u, p..., t)
834-
835-
# DDEs
836-
fn1(u, histfn, p, t) = _fn(u, histfn, p, t)
837-
fn1(u, histfn, p::MTKParameters, t) = _fn(u, histfn, p..., t)
838-
fn1
839-
end
823+
return _fn
840824
else
841825
return let _fn = _fn
842826
fn2(u, p) = _fn(u, p)
843-
fn2(u, p::MTKParameters) = _fn(u, p...)
844827
fn2(::Nothing, p) = _fn([], p)
845-
fn2(::Nothing, p::MTKParameters) = _fn([], p...)
846828
fn2
847829
end
848830
end
@@ -2380,8 +2362,8 @@ function linearization_function(sys::AbstractSystem, inputs,
23802362
u_getter = u_getter
23812363

23822364
function (u, p, t)
2383-
p_setter!(oldps, p_getter(u, p..., t))
2384-
newu = u_getter(u, p..., t)
2365+
p_setter!(oldps, p_getter(u, p, t))
2366+
newu = u_getter(u, p, t)
23852367
return newu, oldps
23862368
end
23872369
end
@@ -2392,20 +2374,15 @@ function linearization_function(sys::AbstractSystem, inputs,
23922374

23932375
function (u, p, t)
23942376
state = ProblemState(; u, p, t)
2395-
return u_getter(state), p_getter(state)
2377+
return u_getter(
2378+
state_values(state), parameter_values(state), current_time(state)),
2379+
p_getter(state)
23962380
end
23972381
end
23982382
end
23992383
initfn = NonlinearFunction(initsys; eval_expression, eval_module)
24002384
initprobmap = build_explicit_observed_function(
24012385
initsys, unknowns(sys); eval_expression, eval_module)
2402-
if has_index_cache(sys) && get_index_cache(sys) !== nothing
2403-
initprobmap = let inner = initprobmap
2404-
fn(u, p::MTKParameters) = inner(u, p...)
2405-
fn(u, p) = inner(u, p)
2406-
fn
2407-
end
2408-
end
24092386
ps = parameters(sys)
24102387
h = build_explicit_observed_function(sys, outputs; eval_expression, eval_module)
24112388
lin_fun = let diff_idxs = diff_idxs,
@@ -2452,7 +2429,7 @@ function linearization_function(sys::AbstractSystem, inputs,
24522429
fg_xz = ForwardDiff.jacobian(uf, u)
24532430
h_xz = ForwardDiff.jacobian(
24542431
let p = p, t = t
2455-
xz -> p isa MTKParameters ? h(xz, p..., t) : h(xz, p, t)
2432+
xz -> h(xz, p, t)
24562433
end, u)
24572434
pf = SciMLBase.ParamJacobianWrapper(fun, t, u)
24582435
fg_u = jacobian_wrt_vars(pf, p, input_idxs, chunk)
@@ -2464,7 +2441,6 @@ function linearization_function(sys::AbstractSystem, inputs,
24642441
end
24652442
hp = let u = u, t = t
24662443
_hp(p) = h(u, p, t)
2467-
_hp(p::MTKParameters) = h(u, p..., t)
24682444
_hp
24692445
end
24702446
h_u = jacobian_wrt_vars(hp, p, input_idxs, chunk)
@@ -2517,7 +2493,7 @@ function linearize_symbolic(sys::AbstractSystem, inputs,
25172493
dx = fun(sts, p..., t)
25182494

25192495
h = build_explicit_observed_function(sys, outputs; eval_expression, eval_module)
2520-
y = h(sts, p..., t)
2496+
y = h(sts, p, t)
25212497

25222498
fg_xz = Symbolics.jacobian(dx, sts)
25232499
fg_u = Symbolics.jacobian(dx, inputs)

test/clock.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,7 @@ eqs = [yd ~ Sample(dt)(y)
514514

515515
@test sol.prob.kwargs[:disc_saved_values][1].t == sol.t[1:2:end] # Test that the discrete-time system executed at every step of the continuous solver. The solver saves each time step twice, one state value before discrete affect and one after.
516516
@test_nowarn ModelingToolkit.build_explicit_observed_function(
517-
model, model.counter.ud)(sol.u[1], prob.p..., sol.t[1])
517+
model, model.counter.ud)(sol.u[1], prob.p, sol.t[1])
518518

519519
@variables x(t)=1.0 y(t)=1.0
520520
eqs = [D(y) ~ Hold(x)

test/input_output_handling.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,9 @@ if VERSION >= v"1.8" # :opaque_closure not supported before
144144
drop_expr = identity)
145145
x = randn(size(A, 1))
146146
u = randn(size(B, 2))
147-
p = getindex.(
147+
p = (getindex.(
148148
Ref(ModelingToolkit.defaults_and_guesses(ssys)),
149-
parameters(ssys))
149+
parameters(ssys)),)
150150
y1 = obsf(x, u, p, 0)
151151
y2 = C * x + D * u
152152
@test y1[] y2[]

test/odesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -548,7 +548,7 @@ prob = ODEProblem(
548548
@test_nowarn solve(prob, Tsit5())
549549
obsfn = ModelingToolkit.build_explicit_observed_function(
550550
outersys, bar(3outersys.sys.ms, 3outersys.sys.p))
551-
@test_nowarn obsfn(sol.u[1], prob.p..., sol.t[1])
551+
@test_nowarn obsfn(sol.u[1], prob.p, sol.t[1])
552552

553553
# x/x
554554
@variables x(t)

test/reduction.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ prob1 = ODEProblem(reduced_system, u0, (0.0, 100.0), pp)
119119
solve(prob1, Rodas5())
120120

121121
prob2 = SteadyStateProblem(reduced_system, u0, pp)
122-
@test prob2.f.observed(lorenz2.u, prob2.u0, pp) === 1.0
122+
@test prob2.f.observed(lorenz2.u, prob2.u0, prob2.p) === 1.0
123123

124124
# issue #724 and #716
125125
let

test/serialization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ for var in all_obs
5050
f = ModelingToolkit.build_explicit_observed_function(ss, var; expression = true)
5151
sym = ModelingToolkit.getname(var) |> string
5252
ex = :(if name == Symbol($sym)
53-
return $f(u0, p..., t)
53+
return $f(u0, p, t)
5454
end)
5555
push!(obs_exps, ex)
5656
end

0 commit comments

Comments
 (0)