Skip to content

Commit 1266976

Browse files
Merge pull request #3151 from AayushSabharwal/as/homotopy-rational-poly
feat: support rational functions in `HomotopyContinuationProblem`
2 parents 7ff50d1 + 7aae63d commit 1266976

File tree

4 files changed

+137
-17
lines changed

4 files changed

+137
-17
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ REPL = "1"
121121
RecursiveArrayTools = "3.26"
122122
Reexport = "0.2, 1"
123123
RuntimeGeneratedFunctions = "0.5.9"
124-
SciMLBase = "2.56.1"
124+
SciMLBase = "2.57.1"
125125
SciMLStructures = "1.0"
126126
Serialization = "1"
127127
Setfield = "0.7, 0.8, 1"

ext/MTKHomotopyContinuationExt.jl

Lines changed: 87 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module MTKHomotopyContinuationExt
22

33
using ModelingToolkit
44
using ModelingToolkit.SciMLBase
5-
using ModelingToolkit.Symbolics: unwrap, symtype
5+
using ModelingToolkit.Symbolics: unwrap, symtype, BasicSymbolic, simplify_fractions
66
using ModelingToolkit.SymbolicIndexingInterface
77
using ModelingToolkit.DocStringExtensions
88
using HomotopyContinuation
@@ -27,7 +27,7 @@ function is_polynomial(x, wrt)
2727
contains_variable(x, wrt) || return true
2828
any(isequal(x), wrt) && return true
2929

30-
if operation(x) in (*, +, -)
30+
if operation(x) in (*, +, -, /)
3131
return all(y -> is_polynomial(y, wrt), arguments(x))
3232
end
3333
if operation(x) == (^)
@@ -57,6 +57,57 @@ end
5757
"""
5858
$(TYPEDSIGNATURES)
5959
60+
Given a `x`, a polynomial in variables in `wrt` which may contain rational functions,
61+
express `x` as a single rational function with polynomial `num` and denominator `den`.
62+
Return `(num, den)`.
63+
"""
64+
function handle_rational_polynomials(x, wrt)
65+
x = unwrap(x)
66+
symbolic_type(x) == NotSymbolic() && return x, 1
67+
iscall(x) || return x, 1
68+
contains_variable(x, wrt) || return x, 1
69+
any(isequal(x), wrt) && return x, 1
70+
71+
# simplify_fractions cancels out some common factors
72+
# and expands (a / b)^c to a^c / b^c, so we only need
73+
# to handle these cases
74+
x = simplify_fractions(x)
75+
op = operation(x)
76+
args = arguments(x)
77+
78+
if op == /
79+
# numerator and denominator are trivial
80+
num, den = args
81+
# but also search for rational functions in numerator
82+
n, d = handle_rational_polynomials(num, wrt)
83+
num, den = n, den * d
84+
elseif op == +
85+
num = 0
86+
den = 1
87+
88+
# we don't need to do common denominator
89+
# because we don't care about cases where denominator
90+
# is zero. The expression is zero when all the numerators
91+
# are zero.
92+
for arg in args
93+
n, d = handle_rational_polynomials(arg, wrt)
94+
num += n
95+
den *= d
96+
end
97+
else
98+
return x, 1
99+
end
100+
# if the denominator isn't a polynomial in `wrt`, better to not include it
101+
# to reduce the size of the gcd polynomial
102+
if !contains_variable(den, wrt)
103+
return num / den, 1
104+
end
105+
return num, den
106+
end
107+
108+
"""
109+
$(TYPEDSIGNATURES)
110+
60111
Convert `expr` from a symbolics expression to one that uses `HomotopyContinuation.ModelKit`.
61112
"""
62113
function symbolics_to_hc(expr)
@@ -139,51 +190,74 @@ function MTK.HomotopyContinuationProblem(
139190
dvs = unknowns(sys)
140191
eqs = equations(sys)
141192

142-
for eq in eqs
193+
denoms = []
194+
eqs2 = map(eqs) do eq
143195
if !is_polynomial(eq.lhs, dvs) || !is_polynomial(eq.rhs, dvs)
144196
error("Equation $eq is not a polynomial in the unknowns. See warnings for further details.")
145197
end
198+
num, den = handle_rational_polynomials(eq.rhs - eq.lhs, dvs)
199+
push!(denoms, den)
200+
return 0 ~ num
146201
end
147202

148-
nlfn, u0, p = MTK.process_SciMLProblem(NonlinearFunction{true}, sys, u0map, parammap;
203+
sys2 = MTK.@set sys.eqs = eqs2
204+
205+
nlfn, u0, p = MTK.process_SciMLProblem(NonlinearFunction{true}, sys2, u0map, parammap;
149206
jac = true, eval_expression, eval_module)
150207

208+
denominator = MTK.build_explicit_observed_function(sys, denoms)
209+
151210
hvars = symbolics_to_hc.(dvs)
152211
mtkhsys = MTKHomotopySystem(nlfn.f, p, nlfn.jac, hvars, length(eqs))
153212

154213
obsfn = MTK.ObservedFunctionCache(sys; eval_expression, eval_module)
155214

156-
return MTK.HomotopyContinuationProblem(u0, mtkhsys, sys, obsfn)
215+
return MTK.HomotopyContinuationProblem(u0, mtkhsys, denominator, sys, obsfn)
157216
end
158217

159218
"""
160219
$(TYPEDSIGNATURES)
161220
162221
Solve a `HomotopyContinuationProblem`. Ignores the algorithm passed to it, and always
163-
uses `HomotopyContinuation.jl`. All keyword arguments are forwarded to
164-
`HomotopyContinuation.solve`. The original solution as returned by `HomotopyContinuation.jl`
165-
will be available in the `.original` field of the returned `NonlinearSolution`.
222+
uses `HomotopyContinuation.jl`. All keyword arguments except the ones listed below are
223+
forwarded to `HomotopyContinuation.solve`. The original solution as returned by
224+
`HomotopyContinuation.jl` will be available in the `.original` field of the returned
225+
`NonlinearSolution`.
166226
167227
All keyword arguments have their default values in HomotopyContinuation.jl, except
168228
`show_progress` which defaults to `false`.
229+
230+
Extra keyword arguments:
231+
- `denominator_abstol`: In case `prob` is solving a rational function, roots which cause
232+
the denominator to be below `denominator_abstol` will be discarded.
169233
"""
170234
function CommonSolve.solve(prob::MTK.HomotopyContinuationProblem,
171-
alg = nothing; show_progress = false, kwargs...)
235+
alg = nothing; show_progress = false, denominator_abstol = 1e-8, kwargs...)
172236
sol = HomotopyContinuation.solve(
173237
prob.homotopy_continuation_system; show_progress, kwargs...)
174238
realsols = HomotopyContinuation.results(sol; only_real = true)
175239
if isempty(realsols)
176240
u = state_values(prob)
177-
resid = prob.homotopy_continuation_system(u)
178241
retcode = SciMLBase.ReturnCode.ConvergenceFailure
179242
else
243+
T = eltype(state_values(prob))
180244
distance, idx = findmin(realsols) do result
245+
if any(<=(denominator_abstol),
246+
prob.denominator(real.(result.solution), parameter_values(prob)))
247+
return T(Inf)
248+
end
181249
norm(result.solution - state_values(prob))
182250
end
183-
u = real.(realsols[idx].solution)
184-
resid = prob.homotopy_continuation_system(u)
185-
retcode = SciMLBase.ReturnCode.Success
251+
# all roots cause denominator to be zero
252+
if isinf(distance)
253+
u = state_values(prob)
254+
retcode = SciMLBase.ReturnCode.Infeasible
255+
else
256+
u = real.(realsols[idx].solution)
257+
retcode = SciMLBase.ReturnCode.Success
258+
end
186259
end
260+
resid = prob.homotopy_continuation_system(u)
187261

188262
return SciMLBase.build_solution(
189263
prob, :HomotopyContinuation, u, resid; retcode, original = sol)

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -573,7 +573,7 @@ A type of Nonlinear problem which specializes on polynomial systems and uses
573573
HomotopyContinuation.jl to solve the system. Requires importing HomotopyContinuation.jl to
574574
create and solve.
575575
"""
576-
struct HomotopyContinuationProblem{uType, H, O} <:
576+
struct HomotopyContinuationProblem{uType, H, D, O} <:
577577
SciMLBase.AbstractNonlinearProblem{uType, true}
578578
"""
579579
The initial values of states in the system. If there are multiple real roots of
@@ -586,6 +586,12 @@ struct HomotopyContinuationProblem{uType, H, O} <:
586586
"""
587587
homotopy_continuation_system::H
588588
"""
589+
A function with signature `(u, p) -> resid`. In case of rational functions, this
590+
is used to rule out roots of the system which would cause the denominator to be
591+
zero.
592+
"""
593+
denominator::D
594+
"""
589595
The `NonlinearSystem` used to create this problem. Used for symbolic indexing.
590596
"""
591597
sys::NonlinearSystem

test/extensions/homotopy_continuation.jl

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,47 @@ end
8282
@mtkbuild sys = NonlinearSystem([x^x - x ~ 0])
8383
@test_warn ["Exponent", "unknowns"] @test_throws "not a polynomial" HomotopyContinuationProblem(
8484
sys, [])
85-
@mtkbuild sys = NonlinearSystem([((x^2) / (x + 3))^2 + x ~ 0])
86-
@test_warn ["Base", "not a polynomial", "Unrecognized operation", "/"] @test_throws "not a polynomial" HomotopyContinuationProblem(
85+
@mtkbuild sys = NonlinearSystem([((x^2) / sin(x))^2 + x ~ 0])
86+
@test_warn ["Unrecognized", "sin"] @test_throws "not a polynomial" HomotopyContinuationProblem(
8787
sys, [])
8888
end
89+
90+
@testset "Rational functions" begin
91+
@variables x=2.0 y=2.0
92+
@parameters n = 4
93+
@mtkbuild sys = NonlinearSystem([
94+
0 ~ (x^2 - n * x + n) * (x - 1) / (x - 2) / (x - 3)
95+
])
96+
prob = HomotopyContinuationProblem(sys, [])
97+
sol = solve(prob; threading = false)
98+
@test sol[x] 1.0
99+
p = parameter_values(prob)
100+
for invalid in [2.0, 3.0]
101+
@test prob.denominator([invalid], p)[1] <= 1e-8
102+
end
103+
104+
@named sys = NonlinearSystem(
105+
[
106+
0 ~ (x - 2) / (x - 4) + ((x - 3) / (y - 7)) / ((x^2 - 4x + y) / (x - 2.5)),
107+
0 ~ ((y - 3) / (y - 4)) * (n / (y - 5)) + ((x - 1.5) / (x - 5.5))^2
108+
],
109+
[x, y],
110+
[n])
111+
sys = complete(sys)
112+
prob = HomotopyContinuationProblem(sys, [])
113+
sol = solve(prob; threading = false)
114+
disallowed_x = [4, 5.5]
115+
disallowed_y = [7, 5, 4]
116+
@test all(!isapprox(sol[x]; atol = 1e-8), disallowed_x)
117+
@test all(!isapprox(sol[y]; atol = 1e-8), disallowed_y)
118+
@test sol[x^2 - 4x + y] >= 1e-8
119+
120+
p = parameter_values(prob)
121+
for val in disallowed_x
122+
@test any(<=(1e-8), prob.denominator([val, 2.0], p))
123+
end
124+
for val in disallowed_y
125+
@test any(<=(1e-8), prob.denominator([2.0, val], p))
126+
end
127+
@test prob.denominator([2.0, 4.0], p)[1] <= 1e-8
128+
end

0 commit comments

Comments
 (0)