Skip to content

Commit 213246e

Browse files
Merge pull request #3157 from AayushSabharwal/as/interval-nlprob
feat: add support for `IntervalNonlinearProblem`
2 parents fa14fdd + 108d055 commit 213246e

File tree

3 files changed

+139
-6
lines changed

3 files changed

+139
-6
lines changed

src/ModelingToolkit.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,8 @@ export JumpSystem
226226
export ODEProblem, SDEProblem
227227
export NonlinearFunction, NonlinearFunctionExpr
228228
export NonlinearProblem, NonlinearProblemExpr
229+
export IntervalNonlinearFunction, IntervalNonlinearFunctionExpr
230+
export IntervalNonlinearProblem, IntervalNonlinearProblemExpr
229231
export OptimizationProblem, OptimizationProblemExpr, constraints
230232
export SteadyStateProblem, SteadyStateProblemExpr
231233
export JumpProblem

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 115 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -258,13 +258,18 @@ end
258258

259259
function generate_function(
260260
sys::NonlinearSystem, dvs = unknowns(sys), ps = parameters(sys);
261-
wrap_code = identity, kwargs...)
261+
wrap_code = identity, scalar = false, kwargs...)
262262
rhss = [deq.rhs for deq in equations(sys)]
263+
dvs′ = value.(dvs)
264+
if scalar
265+
rhss = only(rhss)
266+
dvs′ = only(dvs)
267+
end
263268
pre, sol_states = get_substitutions_and_solved_unknowns(sys)
264269
wrap_code = wrap_code .∘ wrap_array_vars(sys, rhss; dvs, ps) .∘
265-
wrap_parameter_dependencies(sys, false)
270+
wrap_parameter_dependencies(sys, scalar)
266271
p = reorder_parameters(sys, value.(ps))
267-
return build_function(rhss, value.(dvs), p...; postprocess_fbody = pre,
272+
return build_function(rhss, dvs, p...; postprocess_fbody = pre,
268273
states = sol_states, wrap_code, kwargs...)
269274
end
270275

@@ -288,7 +293,7 @@ SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(sys),
288293
kwargs...) where {iip}
289294
```
290295
291-
Create an `NonlinearFunction` from the [`NonlinearSystem`](@ref). The arguments
296+
Create a `NonlinearFunction` from the [`NonlinearSystem`](@ref). The arguments
292297
`dvs` and `ps` are used to set the order of the dependent variable and parameter
293298
vectors, respectively.
294299
"""
@@ -351,6 +356,34 @@ function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(s
351356
observed = observedfun)
352357
end
353358

359+
"""
360+
$(TYPEDSIGNATURES)
361+
362+
Create an `IntervalNonlinearFunction` from the [`NonlinearSystem`](@ref). The arguments
363+
`dvs` and `ps` are used to set the order of the dependent variable and parameter vectors,
364+
respectively.
365+
"""
366+
function SciMLBase.IntervalNonlinearFunction(
367+
sys::NonlinearSystem, dvs = unknowns(sys), ps = parameters(sys), u0 = nothing;
368+
p = nothing, eval_expression = false, eval_module = @__MODULE__, kwargs...)
369+
if !iscomplete(sys)
370+
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `IntervalNonlinearFunction`")
371+
end
372+
if !isone(length(dvs)) || !isone(length(equations(sys)))
373+
error("`IntervalNonlinearFunction` only supports systems with a single equation and a single unknown.")
374+
end
375+
376+
f_gen = generate_function(
377+
sys, dvs, ps; expression = Val{true}, scalar = true, kwargs...)
378+
f_oop = eval_or_rgf(f_gen; eval_expression, eval_module)
379+
f(u, p) = f_oop(u, p)
380+
f(u, p::MTKParameters) = f_oop(u, p...)
381+
382+
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
383+
384+
IntervalNonlinearFunction{false}(f; observed = observedfun, sys = sys)
385+
end
386+
354387
"""
355388
```julia
356389
SciMLBase.NonlinearFunctionExpr{iip}(sys::NonlinearSystem, dvs = unknowns(sys),
@@ -361,14 +394,14 @@ SciMLBase.NonlinearFunctionExpr{iip}(sys::NonlinearSystem, dvs = unknowns(sys),
361394
kwargs...) where {iip}
362395
```
363396
364-
Create a Julia expression for an `ODEFunction` from the [`ODESystem`](@ref).
397+
Create a Julia expression for a `NonlinearFunction` from the [`NonlinearSystem`](@ref).
365398
The arguments `dvs` and `ps` are used to set the order of the dependent
366399
variable and parameter vectors, respectively.
367400
"""
368401
struct NonlinearFunctionExpr{iip} end
369402

370403
function NonlinearFunctionExpr{iip}(sys::NonlinearSystem, dvs = unknowns(sys),
371-
ps = parameters(sys), u0 = nothing, p = nothing;
404+
ps = parameters(sys), u0 = nothing; p = nothing,
372405
version = nothing, tgrad = false,
373406
jac = false,
374407
linenumbers = false,
@@ -412,6 +445,34 @@ function NonlinearFunctionExpr{iip}(sys::NonlinearSystem, dvs = unknowns(sys),
412445
!linenumbers ? Base.remove_linenums!(ex) : ex
413446
end
414447

448+
"""
449+
$(TYPEDSIGNATURES)
450+
451+
Create a Julia expression for an `IntervalNonlinearFunction` from the
452+
[`NonlinearSystem`](@ref). The arguments `dvs` and `ps` are used to set the order of the
453+
dependent variable and parameter vectors, respectively.
454+
"""
455+
function IntervalNonlinearFunctionExpr(
456+
sys::NonlinearSystem, dvs = unknowns(sys), ps = parameters(sys),
457+
u0 = nothing; p = nothing, linenumbers = false, kwargs...)
458+
if !iscomplete(sys)
459+
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `IntervalNonlinearFunctionExpr`")
460+
end
461+
if !isone(length(dvs)) || !isone(length(equations(sys)))
462+
error("`IntervalNonlinearFunctionExpr` only supports systems with a single equation and a single unknown.")
463+
end
464+
465+
f = generate_function(sys, dvs, ps; expression = Val{true}, scalar = true, kwargs...)
466+
467+
IntervalNonlinearFunction{false}(f; sys = sys)
468+
469+
ex = quote
470+
f = $f
471+
NonlinearFunction{false}(f)
472+
end
473+
!linenumbers ? Base.remove_linenums!(ex) : ex
474+
end
475+
415476
"""
416477
```julia
417478
DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem, u0map,
@@ -470,6 +531,26 @@ function DiffEqBase.NonlinearLeastSquaresProblem{iip}(sys::NonlinearSystem, u0ma
470531
NonlinearLeastSquaresProblem{iip}(f, u0, p; filter_kwargs(kwargs)...)
471532
end
472533

534+
"""
535+
$(TYPEDSIGNATURES)
536+
537+
Generate an `IntervalNonlinearProblem` from a `NonlinearSystem` and allow for automatically
538+
symbolically calculating numerical enhancements.
539+
"""
540+
function DiffEqBase.IntervalNonlinearProblem(sys::NonlinearSystem, uspan::NTuple{2},
541+
parammap = SciMLBase.NullParameters(); kwargs...)
542+
if !iscomplete(sys)
543+
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `IntervalNonlinearProblem`")
544+
end
545+
if !isone(length(unknowns(sys))) || !isone(length(equations(sys)))
546+
error("`IntervalNonlinearProblem` only supports with a single equation and a single unknown.")
547+
end
548+
f, u0, p = process_SciMLProblem(
549+
IntervalNonlinearFunction, sys, unknowns(sys) .=> uspan[1], parammap; kwargs...)
550+
551+
return IntervalNonlinearProblem(f, uspan, p; filter_kwargs(kwargs)...)
552+
end
553+
473554
"""
474555
```julia
475556
DiffEqBase.NonlinearProblemExpr{iip}(sys::NonlinearSystem, u0map,
@@ -550,6 +631,34 @@ function NonlinearLeastSquaresProblemExpr{iip}(sys::NonlinearSystem, u0map,
550631
!linenumbers ? Base.remove_linenums!(ex) : ex
551632
end
552633

634+
"""
635+
$(TYPEDSIGNATURES)
636+
637+
Generates a Julia expression for an IntervalNonlinearProblem from a
638+
NonlinearSystem and allows for automatically symbolically calculating
639+
numerical enhancements.
640+
"""
641+
function IntervalNonlinearProblemExpr(sys::NonlinearSystem, uspan::NTuple{2},
642+
parammap = SciMLBase.NullParameters(); kwargs...)
643+
if !iscomplete(sys)
644+
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `IntervalNonlinearProblemExpr`")
645+
end
646+
if !isone(length(unknowns(sys))) || !isone(length(equations(sys)))
647+
error("`IntervalNonlinearProblemExpr` only supports with a single equation and a single unknown.")
648+
end
649+
f, u0, p = process_SciMLProblem(
650+
IntervalNonlinearFunctionExpr, sys, unknowns(sys) .=> uspan[1], parammap; kwargs...)
651+
linenumbers = get(kwargs, :linenumbers, true)
652+
653+
ex = quote
654+
f = $f
655+
uspan = $uspan
656+
p = $p
657+
IntervalNonlinearProblem(f, uspan, p; $(filter_kwargs(kwargs)...))
658+
end
659+
!linenumbers ? Base.remove_linenums!(ex) : ex
660+
end
661+
553662
function flatten(sys::NonlinearSystem, noeqs = false)
554663
systems = get_systems(sys)
555664
if isempty(systems)

test/nonlinearsystem.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,3 +358,25 @@ end
358358
@test_nowarn solve(prob)
359359
end
360360
end
361+
362+
@testset "IntervalNonlinearProblem" begin
363+
@variables x
364+
@parameters p
365+
@named nlsys = NonlinearSystem([0 ~ x * x - p])
366+
367+
for sys in [complete(nlsys), complete(nlsys; split = false)]
368+
prob = IntervalNonlinearProblem(sys, (0.0, 2.0), [p => 1.0])
369+
sol = @test_nowarn solve(prob, ITP())
370+
@test SciMLBase.successful_retcode(sol)
371+
@test_nowarn IntervalNonlinearProblemExpr(sys, (0.0, 2.0), [p => 1.0])
372+
end
373+
374+
@variables y
375+
@mtkbuild sys = NonlinearSystem([0 ~ x * x - p * x + p, 0 ~ x * y + p])
376+
@test_throws ["single equation", "unknown"] IntervalNonlinearProblem(sys, (0.0, 1.0))
377+
@test_throws ["single equation", "unknown"] IntervalNonlinearFunction(sys, (0.0, 1.0))
378+
@test_throws ["single equation", "unknown"] IntervalNonlinearProblemExpr(
379+
sys, (0.0, 1.0))
380+
@test_throws ["single equation", "unknown"] IntervalNonlinearFunctionExpr(
381+
sys, (0.0, 1.0))
382+
end

0 commit comments

Comments
 (0)