Skip to content

Commit 8933782

Browse files
Handle modelingtoolkitize for nonlinearleastsquaresproblem
Fixes #2669
1 parent 622408b commit 8933782

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

src/systems/nonlinear/modelingtoolkitize.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ $(TYPEDSIGNATURES)
44
Generate `NonlinearSystem`, dependent variables, and parameters from an `NonlinearProblem`.
55
"""
66
function modelingtoolkitize(
7-
prob::NonlinearProblem; u_names = nothing, p_names = nothing, kwargs...)
7+
prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem};
8+
u_names = nothing, p_names = nothing, kwargs...)
89
p = prob.p
910
has_p = !(p isa Union{DiffEqBase.NullParameters, Nothing})
1011

@@ -37,13 +38,18 @@ function modelingtoolkitize(
3738
end
3839

3940
if DiffEqBase.isinplace(prob)
40-
rhs = ArrayInterface.restructure(prob.u0, similar(vars, Num))
41+
if prob isa NonlinearLeastSquaresProblem
42+
rhs = ArrayInterface.restructure(prob.f.resid_prototype, similar(prob.f.resid_prototype, Num))
43+
else
44+
rhs = ArrayInterface.restructure(prob.u0, similar(vars, Num))
45+
end
4146
prob.f(rhs, vars, params)
47+
eqs = vcat([0.0 ~ rhs[i] for i in 1:length(prob.f.resid_prototype)]...)
4248
else
4349
rhs = prob.f(vars, params)
50+
out_def = prob.f(prob.u0, prob.p)
51+
eqs = vcat([0.0 ~ rhs[i] for i in 1:length(out_def)]...)
4452
end
45-
out_def = prob.f(prob.u0, prob.p)
46-
eqs = vcat([0.0 ~ rhs[i] for i in 1:length(out_def)]...)
4753

4854
sts = vec(collect(vars))
4955
_params = params

test/modelingtoolkitize.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,3 +473,17 @@ sys = modelingtoolkitize(prob)
473473
end
474474
end
475475
end
476+
477+
## NonlinearLeastSquaresProblem
478+
479+
function nlls!(du, u, p)
480+
du[1] = 2u[1] - 2
481+
du[2] = u[1] - 4u[2]
482+
du[3] = 0
483+
end
484+
u0 = [0.0, 0.0]
485+
prob = NonlinearLeastSquaresProblem(
486+
NonlinearFunction(nlls!, resid_prototype = zeros(3)), u0)
487+
sys = modelingtoolkitize(prob)
488+
@test length(equations(sys)) == 3
489+
@test length(equations(structural_simplify(sys; fully_determined = false))) == 0

0 commit comments

Comments
 (0)