diff --git a/Project.toml b/Project.toml index 79b499748c..1e96c0e4c0 100644 --- a/Project.toml +++ b/Project.toml @@ -41,7 +41,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" Moshi = "2e0e35c7-a2e4-4343-998d-7ef72827ed2d" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" -NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8" @@ -49,7 +48,6 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" -SCCNonlinearSolve = "9dfe8606-65a1-4bb3-9748-cb89d1561431" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SciMLPublic = "431bcebd-1456-4ced-9d72-93c2757fff0b" SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" @@ -73,7 +71,9 @@ DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6" FMI = "14a09403-18e3-468f-ad8a-74f8dda2d9ac" InfiniteOpt = "20393b10-9daf-11e9-18c9-8db751c92c57" LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" +NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" Pyomo = "0e8e1daf-01b5-4eba-a626-3897743a3816" +SCCNonlinearSolve = "9dfe8606-65a1-4bb3-9748-cb89d1561431" [extensions] MTKBifurcationKitExt = "BifurcationKit" @@ -82,6 +82,7 @@ MTKDeepDiffsExt = "DeepDiffs" MTKFMIExt = "FMI" MTKInfiniteOptExt = "InfiniteOpt" MTKLabelledArraysExt = "LabelledArrays" +MTKNonlinearSolveExt = ["NonlinearSolve", "SCCNonlinearSolve"] MTKPyomoDynamicOptExt = "Pyomo" [compat] diff --git a/ext/MTKNonlinearSolveExt/MTKNonlinearSolveExt.jl b/ext/MTKNonlinearSolveExt/MTKNonlinearSolveExt.jl new file mode 100644 index 0000000000..62e6ef27ac --- /dev/null +++ b/ext/MTKNonlinearSolveExt/MTKNonlinearSolveExt.jl @@ -0,0 +1,10 @@ +module MTKNonlinearSolveExt + +using ModelingToolkit +using NonlinearSolve +using SCCNonlinearSolve + +# Export the TrustRegion algorithm for use in linearization +ModelingToolkit._get_default_nlsolve_alg() = TrustRegion() + +end \ No newline at end of file diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 2c259058b0..f7b7188f1d 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -52,8 +52,7 @@ import JuliaFormatter using MLStyle import Moshi using Moshi.Data: @data -using NonlinearSolve -import SCCNonlinearSolve +# NonlinearSolve and SCCNonlinearSolve are now loaded via extension using ImplicitDiscreteSolve using Reexport using RecursiveArrayTools @@ -379,4 +378,7 @@ PrecompileTools.@compile_workload begin end end +# Default nonlinear solver algorithm - will be overridden by extension when NonlinearSolve is loaded +_get_default_nlsolve_alg() = error("NonlinearSolve.jl is required for linearization with initialization. Please load NonlinearSolve.jl to use this functionality.") + end # module diff --git a/src/linearization.jl b/src/linearization.jl index f2d73f6bee..9030832876 100644 --- a/src/linearization.jl +++ b/src/linearization.jl @@ -1,5 +1,5 @@ """ - lin_fun, simplified_sys = linearization_function(sys::AbstractSystem, inputs, outputs; simplify = false, initialize = true, initialization_solver_alg = TrustRegion(), kwargs...) + lin_fun, simplified_sys = linearization_function(sys::AbstractSystem, inputs, outputs; simplify = false, initialize = true, initialization_solver_alg = nothing, kwargs...) Return a function that linearizes the system `sys`. The function [`linearize`](@ref) provides a higher-level and easier to use interface. @@ -24,7 +24,7 @@ The `simplified_sys` has undergone [`mtkcompile`](@ref) and had any occurring in - `outputs`: A vector of variables that indicate the outputs of the linearized input-output model. - `simplify`: Apply simplification in tearing. - `initialize`: If true, a check is performed to ensure that the operating point is consistent (satisfies algebraic equations). If the op is not consistent, initialization is performed. - - `initialization_solver_alg`: A NonlinearSolve algorithm to use for solving for a feasible set of state and algebraic variables that satisfies the specified operating point. + - `initialization_solver_alg`: A NonlinearSolve algorithm to use for solving for a feasible set of state and algebraic variables that satisfies the specified operating point. If `nothing` (default), a default algorithm will be used when NonlinearSolve.jl is loaded. - `autodiff`: An `ADType` supported by DifferentiationInterface.jl to use for calculating the necessary jacobians. Defaults to using `AutoForwardDiff()` - `kwargs`: Are passed on to `find_solvables!` @@ -39,7 +39,7 @@ function linearization_function(sys::AbstractSystem, inputs, op = Dict(), p = DiffEqBase.NullParameters(), zero_dummy_der = false, - initialization_solver_alg = TrustRegion(), + initialization_solver_alg = nothing, autodiff = AutoForwardDiff(), eval_expression = false, eval_module = @__MODULE__, warn_initialize_determined = true, @@ -81,9 +81,16 @@ function linearization_function(sys::AbstractSystem, inputs, ps = parameters(sys) h = build_explicit_observed_function(sys, outputs; eval_expression, eval_module) + # Use default algorithm if none provided and initialization is enabled + actual_solver_alg = if initialization_solver_alg === nothing && initialize + _get_default_nlsolve_alg() + else + initialization_solver_alg + end + initialization_kwargs = (; abstol = initialization_abstol, reltol = initialization_reltol, - nlsolve_alg = initialization_solver_alg) + nlsolve_alg = actual_solver_alg) p = parameter_values(prob) t0 = current_time(prob)