5454
5555PolynomialData () = PolynomialData (BasicSymbolic[], NonPolynomialReason. T[], false )
5656
57+ abstract type PolynomialTransformationError <: Exception end
58+
59+ struct MultivarTerm <: PolynomialTransformationError
60+ term:: Any
61+ vars:: Any
62+ end
63+
64+ function Base. showerror (io:: IO , err:: MultivarTerm )
65+ println (io,
66+ " Cannot convert system to polynomial: Found term $(err. term) which is a function of multiple unknowns $(err. vars) ." )
67+ end
68+
69+ struct MultipleTermsOfSameVar <: PolynomialTransformationError
70+ terms:: Any
71+ var:: Any
72+ end
73+
74+ function Base. showerror (io:: IO , err:: MultipleTermsOfSameVar )
75+ println (io,
76+ " Cannot convert system to polynomial: Found multiple non-polynomial terms $(err. terms) involving the same unknown $(err. var) ." )
77+ end
78+
79+ struct SymbolicSolveFailure <: PolynomialTransformationError
80+ term:: Any
81+ var:: Any
82+ end
83+
84+ function Base. showerror (io:: IO , err:: SymbolicSolveFailure )
85+ println (io,
86+ " Cannot convert system to polynomial: Unable to symbolically solve $(err. term) for $(err. var) ." )
87+ end
88+
89+ struct NemoNotLoaded <: PolynomialTransformationError end
90+
91+ function Base. showerror (io:: IO , err:: NemoNotLoaded )
92+ println (io,
93+ " ModelingToolkit may be able to solve this system as a polynomial system if `Nemo` is loaded. Run `import Nemo` and try again." )
94+ end
95+
96+ struct VariablesAsPolyAndNonPoly <: PolynomialTransformationError
97+ vars:: Any
98+ end
99+
100+ function Base. showerror (io:: IO , err:: VariablesAsPolyAndNonPoly )
101+ println (io,
102+ " Cannot convert convert system to polynomial: Variables $(err. vars) occur in both polynomial and non-polynomial terms in the system." )
103+ end
104+
57105struct NotPolynomialError <: Exception
58- eq:: Equation
59- data:: PolynomialData
106+ transformation_err:: Union{PolynomialTransformationError, Nothing}
107+ eq:: Vector{Equation}
108+ data:: Vector{PolynomialData}
60109end
61110
62111function Base. showerror (io:: IO , err:: NotPolynomialError )
63- println (io,
64- " Equation $(err. eq) is not a polynomial in the unknowns for the following reasons:" )
65- for (term, reason) in zip (err. data. non_polynomial_terms, err. data. reasons)
66- println (io, display_reason (reason, term))
112+ if err. transformation_err != = nothing
113+ Base. showerror (io, err. transformation_err)
114+ end
115+ for (eq, data) in zip (err. eq, err. data)
116+ if isempty (data. non_polynomial_terms)
117+ continue
118+ end
119+ println (io,
120+ " Equation $(eq) is not a polynomial in the unknowns for the following reasons:" )
121+ for (term, reason) in zip (data. non_polynomial_terms, data. reasons)
122+ println (io, display_reason (reason, term))
123+ end
67124 end
68125end
69126
@@ -86,7 +143,9 @@ function process_polynomial!(data::PolynomialData, x, wrt)
86143 any (isequal (x), wrt) && return true
87144
88145 if operation (x) in (* , + , - , / )
89- return all (y -> is_polynomial! (data, y, wrt), arguments (x))
146+ # `map` because `all` will early exit, but we want to search
147+ # through everything to get all the non-polynomial terms
148+ return all (map (y -> is_polynomial! (data, y, wrt), arguments (x)))
90149 end
91150 if operation (x) == (^ )
92151 b, p = arguments (x)
@@ -105,10 +164,6 @@ function process_polynomial!(data::PolynomialData, x, wrt)
105164 push! (data. reasons, NonPolynomialReason. ExponentContainsUnknowns)
106165 end
107166 base_polynomial = is_polynomial! (data, b, wrt)
108- if ! base_polynomial
109- push! (data. non_polynomial_terms, x)
110- push! (data. reasons, NonPolynomialReason. BaseNotPolynomial)
111- end
112167 return base_polynomial && ! exponent_has_unknowns && is_pow_integer
113168 end
114169 push! (data. non_polynomial_terms, x)
234289
235290SymbolicIndexingInterface. parameter_values (s:: MTKHomotopySystem ) = s. p
236291
292+ struct PolynomialTransformationData
293+ new_var:: BasicSymbolic
294+ term:: BasicSymbolic
295+ inv_term:: Vector
296+ end
297+
237298"""
238299 $(TYPEDSIGNATURES)
239300
@@ -265,18 +326,95 @@ function MTK.HomotopyContinuationProblem(
265326 # CSE/hashconsing.
266327 eqs = full_equations (sys)
267328
268- denoms = []
269- has_parametric_exponents = false
270- eqs2 = map (eqs) do eq
329+ polydata = map (eqs) do eq
271330 data = PolynomialData ()
272331 process_polynomial! (data, eq. lhs, dvs)
273332 process_polynomial! (data, eq. rhs, dvs)
274- has_parametric_exponents |= data. has_parametric_exponent
275- if ! isempty (data. non_polynomial_terms)
276- throw (NotPolynomialError (eq, data))
333+ data
334+ end
335+
336+ has_parametric_exponents = any (d -> d. has_parametric_exponent, polydata)
337+
338+ all_non_poly_terms = mapreduce (d -> d. non_polynomial_terms, vcat, polydata)
339+ unique! (all_non_poly_terms)
340+
341+ var_to_nonpoly = Dict {BasicSymbolic, PolynomialTransformationData} ()
342+
343+ is_poly = true
344+ transformation_err = nothing
345+ for t in all_non_poly_terms
346+ # if the term involves multiple unknowns, we can't invert it
347+ dvs_in_term = map (x -> occursin (x, t), dvs)
348+ if count (dvs_in_term) > 1
349+ transformation_err = MultivarTerm (t, dvs[dvs_in_term])
350+ is_poly = false
351+ break
352+ end
353+ # we already have a substitution solving for `var`
354+ var = dvs[findfirst (dvs_in_term)]
355+ if haskey (var_to_nonpoly, var) && ! isequal (var_to_nonpoly[var]. term, t)
356+ transformation_err = MultipleTermsOfSameVar ([t, var_to_nonpoly[var]. term], var)
357+ is_poly = false
358+ break
359+ end
360+ # we want to solve `term - new_var` for `var`
361+ new_var = gensym (Symbol (var))
362+ new_var = unwrap (only (@variables $ new_var))
363+ invterm = Symbolics. ia_solve (
364+ t - new_var, var; complex_roots = false , periodic_roots = false , warns = false )
365+ # if we can't invert it, quit
366+ if invterm === nothing || isempty (invterm)
367+ transformation_err = SymbolicSolveFailure (t, var)
368+ is_poly = false
369+ break
370+ end
371+ # `ia_solve` returns lazy terms i.e. `asin(1.0)` instead of `pi/2`
372+ # this just evaluates the constant expressions
373+ invterm = Symbolics. substitute .(invterm, (Dict (),))
374+ # RootsOf implies Symbolics couldn't solve the inner polynomial because
375+ # `Nemo` wasn't loaded.
376+ if any (x -> MTK. iscall (x) && MTK. operation (x) == Symbolics. RootsOf, invterm)
377+ transformation_err = NemoNotLoaded ()
378+ is_poly = false
379+ break
277380 end
278- num, den = handle_rational_polynomials (eq. rhs - eq. lhs, dvs)
381+ var_to_nonpoly[var] = PolynomialTransformationData (new_var, t, invterm)
382+ end
383+
384+ if ! is_poly
385+ throw (NotPolynomialError (transformation_err, eqs, polydata))
386+ end
387+
388+ subrules = Dict ()
389+ combinations = Vector[]
390+ new_dvs = []
391+ for x in dvs
392+ if haskey (var_to_nonpoly, x)
393+ _data = var_to_nonpoly[x]
394+ subrules[_data. term] = _data. new_var
395+ push! (combinations, _data. inv_term)
396+ push! (new_dvs, _data. new_var)
397+ else
398+ push! (combinations, [x])
399+ push! (new_dvs, x)
400+ end
401+ end
402+ all_solutions = collect .(collect (Iterators. product (combinations... )))
279403
404+ denoms = []
405+ eqs2 = map (eqs) do eq
406+ t = eq. rhs - eq. lhs
407+ t = Symbolics. fixpoint_sub (t, subrules; maxiters = length (dvs))
408+ # the substituted variable occurs outside the substituted term
409+ poly_and_nonpoly = map (dvs) do x
410+ haskey (var_to_nonpoly, x) && occursin (x, t)
411+ end
412+ if any (poly_and_nonpoly)
413+ throw (NotPolynomialError (
414+ VariablesAsPolyAndNonPoly (dvs[poly_and_nonpoly]), eqs, polydata))
415+ end
416+
417+ num, den = handle_rational_polynomials (t, new_dvs)
280418 # make factors different elements, otherwise the nonzero factors artificially
281419 # inflate the error of the zero factor.
282420 if iscall (den) && operation (den) == *
@@ -292,16 +430,19 @@ function MTK.HomotopyContinuationProblem(
292430 end
293431
294432 sys2 = MTK. @set sys. eqs = eqs2
433+ MTK. @set! sys2. unknowns = new_dvs
295434 # remove observed equations to avoid adding them in codegen
296435 MTK. @set! sys2. observed = Equation[]
297436 MTK. @set! sys2. substitutions = nothing
298437
299- nlfn, u0, p = MTK. process_SciMLProblem (NonlinearFunction{true }, sys2, u0map, parammap;
300- jac = true , eval_expression, eval_module)
438+ _, u0, p = MTK. process_SciMLProblem (
439+ MTK. EmptySciMLFunction, sys, u0map, parammap; eval_expression, eval_module)
440+ nlfn = NonlinearFunction {true} (sys2; jac = true , eval_expression, eval_module)
301441
302- denominator = MTK. build_explicit_observed_function (sys, denoms)
442+ denominator = MTK. build_explicit_observed_function (sys2, denoms)
443+ unpack_solution = MTK. build_explicit_observed_function (sys2, all_solutions)
303444
304- hvars = symbolics_to_hc .(dvs )
445+ hvars = symbolics_to_hc .(new_dvs )
305446 mtkhsys = MTKHomotopySystem (nlfn. f, p, nlfn. jac, hvars, length (eqs))
306447
307448 obsfn = MTK. ObservedFunctionCache (sys; eval_expression, eval_module)
@@ -319,7 +460,7 @@ function MTK.HomotopyContinuationProblem(
319460 solver_and_starts = HomotopyContinuation. solver_startsolutions (mtkhsys; kwargs... )
320461 end
321462 return MTK. HomotopyContinuationProblem (
322- u0, mtkhsys, denominator, sys, obsfn, solver_and_starts)
463+ u0, mtkhsys, denominator, sys, obsfn, solver_and_starts, unpack_solution )
323464end
324465
325466"""
@@ -353,25 +494,35 @@ function CommonSolve.solve(prob::MTK.HomotopyContinuationProblem,
353494 if isempty (realsols)
354495 u = state_values (prob)
355496 retcode = SciMLBase. ReturnCode. ConvergenceFailure
497+ resid = prob. homotopy_continuation_system (u)
356498 else
357499 T = eltype (state_values (prob))
358- distance, idx = findmin (realsols) do result
500+ distance = T (Inf )
501+ u = state_values (prob)
502+ resid = nothing
503+ for result in realsols
359504 if any (<= (denominator_abstol),
360505 prob. denominator (real .(result. solution), parameter_values (prob)))
361- return T (Inf )
506+ continue
507+ end
508+ for truesol in prob. unpack_solution (result. solution, parameter_values (prob))
509+ dist = norm (truesol - state_values (prob))
510+ if dist < distance
511+ distance = dist
512+ u = T .(real .(truesol))
513+ resid = T .(real .(prob. homotopy_continuation_system (result. solution)))
514+ end
362515 end
363- norm (result. solution - state_values (prob))
364516 end
365517 # all roots cause denominator to be zero
366518 if isinf (distance)
367519 u = state_values (prob)
520+ resid = prob. homotopy_continuation_system (u)
368521 retcode = SciMLBase. ReturnCode. Infeasible
369522 else
370- u = real .(realsols[idx]. solution)
371523 retcode = SciMLBase. ReturnCode. Success
372524 end
373525 end
374- resid = prob. homotopy_continuation_system (u)
375526
376527 return SciMLBase. build_solution (
377528 prob, :HomotopyContinuation , u, resid; retcode, original = sol)
0 commit comments