Skip to content

Commit 9000837

Browse files
Merge pull request #3006 from AayushSabharwal/as/fix-linearization
fix: fix missing `MTKParameters` handling in `linearize`
2 parents 4f524e8 + b9df681 commit 9000837

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

src/systems/abstractsystem.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2285,7 +2285,7 @@ function linearization_function(sys::AbstractSystem, inputs,
22852285

22862286
function (u, p, t)
22872287
p_setter!(oldps, p_getter(u, p..., t))
2288-
newu = u_getter(u, p, t)
2288+
newu = u_getter(u, p..., t)
22892289
return newu, oldps
22902290
end
22912291
end
@@ -2303,6 +2303,13 @@ function linearization_function(sys::AbstractSystem, inputs,
23032303
initfn = NonlinearFunction(initsys; eval_expression, eval_module)
23042304
initprobmap = build_explicit_observed_function(
23052305
initsys, unknowns(sys); eval_expression, eval_module)
2306+
if has_index_cache(sys) && get_index_cache(sys) !== nothing
2307+
initprobmap = let inner = initprobmap
2308+
fn(u, p::MTKParameters) = inner(u, p...)
2309+
fn(u, p) = inner(u, p)
2310+
fn
2311+
end
2312+
end
23062313
ps = parameters(sys)
23072314
h = build_explicit_observed_function(sys, outputs; eval_expression, eval_module)
23082315
lin_fun = let diff_idxs = diff_idxs,
@@ -2342,7 +2349,7 @@ function linearization_function(sys::AbstractSystem, inputs,
23422349
initu0, initp = get_initprob_u_p(u, p, t)
23432350
initprob = NonlinearLeastSquaresProblem(initfn, initu0, initp)
23442351
nlsol = solve(initprob, initialization_solver_alg)
2345-
u = initprobmap(nlsol)
2352+
u = initprobmap(state_values(nlsol), parameter_values(nlsol))
23462353
end
23472354
end
23482355
uf = SciMLBase.UJacobianWrapper(fun, t, p)

test/downstream/linearize.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,3 +306,13 @@ linfun, _ = linearization_function(sys, [u], []; op = Dict(x => 2.0))
306306
matrices = linfun([1.0], Dict(p => 3.0), 1.0)
307307
# this would be 1 if the parameter value isn't respected
308308
@test matrices.f_u[] == 3.0
309+
310+
@testset "Issue #2941" begin
311+
@variables x(t) y(t)
312+
@parameters p
313+
eqs = [0 ~ x * log(y) - p]
314+
@named sys = ODESystem(eqs, t; defaults = [p => 1.0])
315+
sys = complete(sys)
316+
@test_nowarn linearize(
317+
sys, [x], []; op = Dict(x => 1.0), allow_input_derivatives = true)
318+
end

0 commit comments

Comments
 (0)