Skip to content

Commit 683a845

Browse files
committed
Merge remote-tracking branch 'origin/master' into add_odes_to_jumpsys
2 parents 2c95e19 + b52bce7 commit 683a845

15 files changed

+247
-57
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
2121
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
2222
DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
2323
DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821"
24+
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
2425
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
2526
Expronicon = "6b7a57c9-7cc1-4fdf-b7f5-e857abae3636"
2627
FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224"
@@ -94,6 +95,7 @@ Distributions = "0.23, 0.24, 0.25"
9495
DocStringExtensions = "0.7, 0.8, 0.9"
9596
DomainSets = "0.6, 0.7"
9697
DynamicQuantities = "^0.11.2, 0.12, 0.13, 1"
98+
EnumX = "1.0.4"
9799
ExprTools = "0.1.10"
98100
Expronicon = "0.8"
99101
FindFirstFunctions = "1"
@@ -111,7 +113,7 @@ Libdl = "1"
111113
LinearAlgebra = "1"
112114
MLStyle = "0.4.17"
113115
NaNMath = "0.3, 1"
114-
NonlinearSolve = "3.14"
116+
NonlinearSolve = "3.14, 4"
115117
OffsetArrays = "1"
116118
OrderedCollections = "1"
117119
OrdinaryDiffEq = "6.82.0"
@@ -125,7 +127,7 @@ SciMLBase = "2.57.1"
125127
SciMLStructures = "1.0"
126128
Serialization = "1"
127129
Setfield = "0.7, 0.8, 1"
128-
SimpleNonlinearSolve = "0.1.0, 1"
130+
SimpleNonlinearSolve = "0.1.0, 1, 2"
129131
SparseArrays = "1"
130132
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
131133
StaticArrays = "0.10, 0.11, 0.12, 1.0"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ Distributions = "0.25"
3030
Documenter = "1"
3131
DynamicQuantities = "^0.11.2, 0.12, 1"
3232
ModelingToolkit = "8.33, 9"
33-
NonlinearSolve = "3"
33+
NonlinearSolve = "3, 4"
3434
Optim = "1.7"
3535
Optimization = "3.9"
3636
OptimizationOptimJL = "0.1"

docs/src/basics/MTKLanguage.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,9 @@ julia> ModelingToolkit.getdefault(model_c1.v)
147147
2.0
148148
```
149149

150-
#### `@extend` begin block
150+
#### `@extend` statement
151151

152-
Partial systems can be extended in a higher system in two ways:
152+
One or more partial systems can be extended in a higher system with `@extend` statements. This can be done in two ways:
153153

154154
- `@extend PartialSystem(var1 = value1)`
155155

@@ -313,7 +313,8 @@ end
313313
- `:components`: The list of sub-components in the form of [[name, sub_component_name],...].
314314
- `:constants`: Dictionary of constants mapped to its metadata.
315315
- `:defaults`: Dictionary of variables and default values specified in the `@defaults`.
316-
- `:extend`: The list of extended unknowns, name given to the base system, and name of the base system.
316+
- `:extend`: The list of extended unknowns, parameters and components, name given to the base system, and name of the base system.
317+
When multiple extend statements are present, latter two are returned as lists.
317318
- `:structural_parameters`: Dictionary of structural parameters mapped to their metadata.
318319
- `:parameters`: Dictionary of symbolic parameters mapped to their metadata. For
319320
parameter arrays, length is added to the metadata as `:size`.

docs/src/tutorials/stochastic_diffeq.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ where the magnitude of the noise scales with (0.3 times) the magnitude of each o
2323

2424
```math
2525
\begin{aligned}
26-
\frac{dx}{dt} &= (\sigma (y-x)) &+ 0.1x\frac{dB}{dt} \\
27-
\frac{dy}{dt} &= (x(\rho-z) - y) &+ 0.1y\frac{dB}{dt} \\
28-
\frac{dz}{dt} &= (xy - \beta z) &+ 0.1z\frac{dB}{dt} \\
26+
\frac{dx}{dt} &= (\sigma (y-x)) &+ 0.3x\frac{dB}{dt} \\
27+
\frac{dy}{dt} &= (x(\rho-z) - y) &+ 0.3y\frac{dB}{dt} \\
28+
\frac{dz}{dt} &= (xy - \beta z) &+ 0.3z\frac{dB}{dt} \\
2929
\end{aligned}
3030
```
3131

ext/MTKHomotopyContinuationExt.jl

Lines changed: 126 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,42 +15,104 @@ function contains_variable(x, wrt)
1515
any(y -> occursin(y, x), wrt)
1616
end
1717

18+
"""
19+
Possible reasons why a term is not polynomial
20+
"""
21+
MTK.EnumX.@enumx NonPolynomialReason begin
22+
NonIntegerExponent
23+
ExponentContainsUnknowns
24+
BaseNotPolynomial
25+
UnrecognizedOperation
26+
end
27+
28+
function display_reason(reason::NonPolynomialReason.T, sym)
29+
if reason == NonPolynomialReason.NonIntegerExponent
30+
pow = arguments(sym)[2]
31+
"In $sym: Exponent $pow is not an integer"
32+
elseif reason == NonPolynomialReason.ExponentContainsUnknowns
33+
pow = arguments(sym)[2]
34+
"In $sym: Exponent $pow contains unknowns of the system"
35+
elseif reason == NonPolynomialReason.BaseNotPolynomial
36+
base = arguments(sym)[1]
37+
"In $sym: Base $base is not a polynomial in the unknowns"
38+
elseif reason == NonPolynomialReason.UnrecognizedOperation
39+
op = operation(sym)
40+
"""
41+
In $sym: Operation $op is not recognized. Allowed polynomial operations are \
42+
`*, /, +, -, ^`.
43+
"""
44+
else
45+
error("This should never happen. Please open an issue in ModelingToolkit.jl.")
46+
end
47+
end
48+
49+
mutable struct PolynomialData
50+
non_polynomial_terms::Vector{BasicSymbolic}
51+
reasons::Vector{NonPolynomialReason.T}
52+
has_parametric_exponent::Bool
53+
end
54+
55+
PolynomialData() = PolynomialData(BasicSymbolic[], NonPolynomialReason.T[], false)
56+
57+
struct NotPolynomialError <: Exception
58+
eq::Equation
59+
data::PolynomialData
60+
end
61+
62+
function Base.showerror(io::IO, err::NotPolynomialError)
63+
println(io,
64+
"Equation $(err.eq) is not a polynomial in the unknowns for the following reasons:")
65+
for (term, reason) in zip(err.data.non_polynomial_terms, err.data.reasons)
66+
println(io, display_reason(reason, term))
67+
end
68+
end
69+
70+
function is_polynomial!(data, y, wrt)
71+
process_polynomial!(data, y, wrt)
72+
isempty(data.reasons)
73+
end
74+
1875
"""
1976
$(TYPEDSIGNATURES)
2077
21-
Check if `x` is polynomial with respect to the variables in `wrt`.
78+
Return information about the polynmial `x` with respect to variables in `wrt`,
79+
writing said information to `data`.
2280
"""
23-
function is_polynomial(x, wrt)
81+
function process_polynomial!(data::PolynomialData, x, wrt)
2482
x = unwrap(x)
2583
symbolic_type(x) == NotSymbolic() && return true
2684
iscall(x) || return true
2785
contains_variable(x, wrt) || return true
2886
any(isequal(x), wrt) && return true
2987

3088
if operation(x) in (*, +, -, /)
31-
return all(y -> is_polynomial(y, wrt), arguments(x))
89+
return all(y -> is_polynomial!(data, y, wrt), arguments(x))
3290
end
3391
if operation(x) == (^)
3492
b, p = arguments(x)
3593
is_pow_integer = symtype(p) <: Integer
3694
if !is_pow_integer
37-
if symbolic_type(p) == NotSymbolic()
38-
@warn "In $x: Exponent $p is not an integer"
39-
else
40-
@warn "In $x: Exponent $p is not an integer. Use `@parameters p::Integer` to declare integer parameters."
41-
end
95+
push!(data.non_polynomial_terms, x)
96+
push!(data.reasons, NonPolynomialReason.NonIntegerExponent)
97+
end
98+
if symbolic_type(p) != NotSymbolic()
99+
data.has_parametric_exponent = true
42100
end
101+
43102
exponent_has_unknowns = contains_variable(p, wrt)
44103
if exponent_has_unknowns
45-
@warn "In $x: Exponent $p cannot contain unknowns of the system."
104+
push!(data.non_polynomial_terms, x)
105+
push!(data.reasons, NonPolynomialReason.ExponentContainsUnknowns)
46106
end
47-
base_polynomial = is_polynomial(b, wrt)
107+
base_polynomial = is_polynomial!(data, b, wrt)
48108
if !base_polynomial
49-
@warn "In $x: Base is not a polynomial"
109+
push!(data.non_polynomial_terms, x)
110+
push!(data.reasons, NonPolynomialReason.BaseNotPolynomial)
50111
end
51112
return base_polynomial && !exponent_has_unknowns && is_pow_integer
52113
end
53-
@warn "In $x: Unrecognized operation $(operation(x)). Allowed polynomial operations are `*, +, -, ^`"
114+
push!(data.non_polynomial_terms, x)
115+
push!(data.reasons, NonPolynomialReason.UnrecognizedOperation)
54116
return false
55117
end
56118

@@ -179,21 +241,39 @@ Create a `HomotopyContinuationProblem` from a `NonlinearSystem` with polynomial
179241
The problem will be solved by HomotopyContinuation.jl. The resultant `NonlinearSolution`
180242
will contain the polynomial root closest to the point specified by `u0map` (if real roots
181243
exist for the system).
244+
245+
Keyword arguments:
246+
- `eval_expression`: Whether to `eval` the generated functions or use a `RuntimeGeneratedFunction`.
247+
- `eval_module`: The module to use for `eval`/`@RuntimeGeneratedFunction`
248+
- `warn_parametric_exponent`: Whether to warn if the system contains a parametric
249+
exponent preventing the homotopy from being cached.
250+
251+
All other keyword arguments are forwarded to `HomotopyContinuation.solver_startsystems`.
182252
"""
183253
function MTK.HomotopyContinuationProblem(
184254
sys::NonlinearSystem, u0map, parammap = nothing; eval_expression = false,
185-
eval_module = ModelingToolkit, kwargs...)
255+
eval_module = ModelingToolkit, warn_parametric_exponent = true, kwargs...)
186256
if !iscomplete(sys)
187257
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `HomotopyContinuationProblem`")
188258
end
189259

190260
dvs = unknowns(sys)
191-
eqs = equations(sys)
261+
# we need to consider `full_equations` because observed also should be
262+
# polynomials (if used in equations) and we don't know if observed is used
263+
# in denominator.
264+
# This is not the most efficient, and would be improved significantly with
265+
# CSE/hashconsing.
266+
eqs = full_equations(sys)
192267

193268
denoms = []
269+
has_parametric_exponents = false
194270
eqs2 = map(eqs) do eq
195-
if !is_polynomial(eq.lhs, dvs) || !is_polynomial(eq.rhs, dvs)
196-
error("Equation $eq is not a polynomial in the unknowns. See warnings for further details.")
271+
data = PolynomialData()
272+
process_polynomial!(data, eq.lhs, dvs)
273+
process_polynomial!(data, eq.rhs, dvs)
274+
has_parametric_exponents |= data.has_parametric_exponent
275+
if !isempty(data.non_polynomial_terms)
276+
throw(NotPolynomialError(eq, data))
197277
end
198278
num, den = handle_rational_polynomials(eq.rhs - eq.lhs, dvs)
199279

@@ -212,6 +292,9 @@ function MTK.HomotopyContinuationProblem(
212292
end
213293

214294
sys2 = MTK.@set sys.eqs = eqs2
295+
# remove observed equations to avoid adding them in codegen
296+
MTK.@set! sys2.observed = Equation[]
297+
MTK.@set! sys2.substitutions = nothing
215298

216299
nlfn, u0, p = MTK.process_SciMLProblem(NonlinearFunction{true}, sys2, u0map, parammap;
217300
jac = true, eval_expression, eval_module)
@@ -223,29 +306,49 @@ function MTK.HomotopyContinuationProblem(
223306

224307
obsfn = MTK.ObservedFunctionCache(sys; eval_expression, eval_module)
225308

226-
return MTK.HomotopyContinuationProblem(u0, mtkhsys, denominator, sys, obsfn)
309+
if has_parametric_exponents
310+
if warn_parametric_exponent
311+
@warn """
312+
The system has parametric exponents, preventing caching of the homotopy. \
313+
This will cause `solve` to be slower. Pass `warn_parametric_exponent \
314+
= false` to turn off this warning
315+
"""
316+
end
317+
solver_and_starts = nothing
318+
else
319+
solver_and_starts = HomotopyContinuation.solver_startsolutions(mtkhsys; kwargs...)
320+
end
321+
return MTK.HomotopyContinuationProblem(
322+
u0, mtkhsys, denominator, sys, obsfn, solver_and_starts)
227323
end
228324

229325
"""
230326
$(TYPEDSIGNATURES)
231327
232328
Solve a `HomotopyContinuationProblem`. Ignores the algorithm passed to it, and always
233-
uses `HomotopyContinuation.jl`. All keyword arguments except the ones listed below are
234-
forwarded to `HomotopyContinuation.solve`. The original solution as returned by
329+
uses `HomotopyContinuation.jl`. The original solution as returned by
235330
`HomotopyContinuation.jl` will be available in the `.original` field of the returned
236331
`NonlinearSolution`.
237332
238-
All keyword arguments have their default values in HomotopyContinuation.jl, except
239-
`show_progress` which defaults to `false`.
333+
All keyword arguments except the ones listed below are forwarded to
334+
`HomotopyContinuation.solve`. Note that the solver and start solutions are precomputed,
335+
and only keyword arguments related to the solve process are valid. All keyword
336+
arguments have their default values in HomotopyContinuation.jl, except `show_progress`
337+
which defaults to `false`.
240338
241339
Extra keyword arguments:
242340
- `denominator_abstol`: In case `prob` is solving a rational function, roots which cause
243341
the denominator to be below `denominator_abstol` will be discarded.
244342
"""
245343
function CommonSolve.solve(prob::MTK.HomotopyContinuationProblem,
246344
alg = nothing; show_progress = false, denominator_abstol = 1e-7, kwargs...)
247-
sol = HomotopyContinuation.solve(
248-
prob.homotopy_continuation_system; show_progress, kwargs...)
345+
if prob.solver_and_starts === nothing
346+
sol = HomotopyContinuation.solve(
347+
prob.homotopy_continuation_system; show_progress, kwargs...)
348+
else
349+
solver, starts = prob.solver_and_starts
350+
sol = HomotopyContinuation.solve(solver, starts; show_progress, kwargs...)
351+
end
249352
realsols = HomotopyContinuation.results(sol; only_real = true)
250353
if isempty(realsols)
251354
u = state_values(prob)

src/ModelingToolkit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ using RecursiveArrayTools
5555
import Graphs: SimpleDiGraph, add_edge!, incidence_matrix
5656
import BlockArrays: BlockedArray, Block, blocksize, blocksizes
5757
import CommonSolve
58+
import EnumX
5859

5960
using RuntimeGeneratedFunctions
6061
using RuntimeGeneratedFunctions: drop_expr

src/systems/abstractsystem.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -918,7 +918,7 @@ Mark a system as completed. A completed system is a system which is done being
918918
defined/modified and is ready for structural analysis or other transformations.
919919
This allows for analyses and optimizations to be performed which require knowing
920920
the global structure of the system.
921-
921+
922922
One property to note is that if a system is complete, the system will no longer
923923
namespace its subsystems or variables, i.e. `isequal(complete(sys).v.i, v.i)`.
924924
"""
@@ -1933,7 +1933,7 @@ function Base.show(
19331933
end
19341934
end
19351935
limited = nrows < nsubs
1936-
limited && print(io, "\n") # too many to print
1936+
limited && print(io, "\n") # too many to print
19371937

19381938
# Print equations
19391939
eqs = equations(sys)
@@ -3043,10 +3043,19 @@ function extend(sys::AbstractSystem, basesys::AbstractSystem;
30433043
return T(args...; kwargs...)
30443044
end
30453045

3046+
function extend(sys, basesys::Vector{T}) where {T <: AbstractSystem}
3047+
foldl(extend, basesys, init = sys)
3048+
end
3049+
30463050
function Base.:(&)(sys::AbstractSystem, basesys::AbstractSystem; kwargs...)
30473051
extend(sys, basesys; kwargs...)
30483052
end
30493053

3054+
function Base.:(&)(
3055+
sys::AbstractSystem, basesys::Vector{T}; kwargs...) where {T <: AbstractSystem}
3056+
extend(sys, basesys; kwargs...)
3057+
end
3058+
30503059
"""
30513060
$(SIGNATURES)
30523061

src/systems/callbacks.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,7 @@ end
552552
"""
553553
compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps; expression, kwargs...)
554554
555-
Returns a function `condition(u,p,t)` returning the `condition(cb)`.
555+
Returns a function `condition(u,t,integrator)` returning the `condition(cb)`.
556556
557557
Notes
558558
@@ -573,7 +573,8 @@ function compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps;
573573
end
574574
expr = build_function(
575575
condit, u, t, p...; expression = Val{true},
576-
wrap_code = condition_header(sys) .∘ wrap_array_vars(sys, condit; dvs, ps) .∘
576+
wrap_code = condition_header(sys) .∘
577+
wrap_array_vars(sys, condit; dvs, ps, inputs = true) .∘
577578
wrap_parameter_dependencies(sys, !(condit isa AbstractArray)),
578579
kwargs...)
579580
if expression == Val{true}

0 commit comments

Comments
 (0)