5454
5555PolynomialData () =  PolynomialData (BasicSymbolic[], NonPolynomialReason. T[], false )
5656
57+ abstract type  PolynomialTransformationError <:  Exception  end 
58+ 
59+ struct  MultivarTerm <:  PolynomialTransformationError 
60+     term:: Any 
61+     vars:: Any 
62+ end 
63+ 
64+ function  Base. showerror (io:: IO , err:: MultivarTerm )
65+     println (io,
66+         " Cannot convert system to polynomial: Found term $(err. term)  which is a function of multiple unknowns $(err. vars) ."  )
67+ end 
68+ 
69+ struct  MultipleTermsOfSameVar <:  PolynomialTransformationError 
70+     terms:: Any 
71+     var:: Any 
72+ end 
73+ 
74+ function  Base. showerror (io:: IO , err:: MultipleTermsOfSameVar )
75+     println (io,
76+         " Cannot convert system to polynomial: Found multiple non-polynomial terms $(err. terms)  involving the same unknown $(err. var) ."  )
77+ end 
78+ 
79+ struct  SymbolicSolveFailure <:  PolynomialTransformationError 
80+     term:: Any 
81+     var:: Any 
82+ end 
83+ 
84+ function  Base. showerror (io:: IO , err:: SymbolicSolveFailure )
85+     println (io,
86+         " Cannot convert system to polynomial: Unable to symbolically solve $(err. term)  for $(err. var) ."  )
87+ end 
88+ 
89+ struct  NemoNotLoaded <:  PolynomialTransformationError  end 
90+ 
91+ function  Base. showerror (io:: IO , err:: NemoNotLoaded )
92+     println (io,
93+         " ModelingToolkit may be able to solve this system as a polynomial system if `Nemo` is loaded. Run `import Nemo` and try again."  )
94+ end 
95+ 
96+ struct  VariablesAsPolyAndNonPoly <:  PolynomialTransformationError 
97+     vars:: Any 
98+ end 
99+ 
100+ function  Base. showerror (io:: IO , err:: VariablesAsPolyAndNonPoly )
101+     println (io,
102+         " Cannot convert convert system to polynomial: Variables $(err. vars)  occur in both polynomial and non-polynomial terms in the system."  )
103+ end 
104+ 
57105struct  NotPolynomialError <:  Exception 
58-     eq:: Equation 
59-     data:: PolynomialData 
106+     transformation_err:: Union{PolynomialTransformationError, Nothing} 
107+     eq:: Vector{Equation} 
108+     data:: Vector{PolynomialData} 
60109end 
61110
62111function  Base. showerror (io:: IO , err:: NotPolynomialError )
63-     println (io,
64-         " Equation $(err. eq)  is not a polynomial in the unknowns for the following reasons:"  )
65-     for  (term, reason) in  zip (err. data. non_polynomial_terms, err. data. reasons)
66-         println (io, display_reason (reason, term))
112+     if  err. transformation_err != =  nothing 
113+         Base. showerror (io, err. transformation_err)
114+     end 
115+     for  (eq, data) in  zip (err. eq, err. data)
116+         if  isempty (data. non_polynomial_terms)
117+             continue 
118+         end 
119+         println (io,
120+             " Equation $(eq)  is not a polynomial in the unknowns for the following reasons:"  )
121+         for  (term, reason) in  zip (data. non_polynomial_terms, data. reasons)
122+             println (io, display_reason (reason, term))
123+         end 
67124    end 
68125end 
69126
@@ -86,7 +143,9 @@ function process_polynomial!(data::PolynomialData, x, wrt)
86143    any (isequal (x), wrt) &&  return  true 
87144
88145    if  operation (x) in  (* , + , - , / )
89-         return  all (y ->  is_polynomial! (data, y, wrt), arguments (x))
146+         #  `map` because `all` will early exit, but we want to search
147+         #  through everything to get all the non-polynomial terms
148+         return  all (map (y ->  is_polynomial! (data, y, wrt), arguments (x)))
90149    end 
91150    if  operation (x) ==  (^ )
92151        b, p =  arguments (x)
@@ -105,10 +164,6 @@ function process_polynomial!(data::PolynomialData, x, wrt)
105164            push! (data. reasons, NonPolynomialReason. ExponentContainsUnknowns)
106165        end 
107166        base_polynomial =  is_polynomial! (data, b, wrt)
108-         if  ! base_polynomial
109-             push! (data. non_polynomial_terms, x)
110-             push! (data. reasons, NonPolynomialReason. BaseNotPolynomial)
111-         end 
112167        return  base_polynomial &&  ! exponent_has_unknowns &&  is_pow_integer
113168    end 
114169    push! (data. non_polynomial_terms, x)
234289
235290SymbolicIndexingInterface. parameter_values (s:: MTKHomotopySystem ) =  s. p
236291
292+ struct  PolynomialTransformationData
293+     new_var:: BasicSymbolic 
294+     term:: BasicSymbolic 
295+     inv_term:: Vector 
296+ end 
297+ 
237298""" 
238299    $(TYPEDSIGNATURES)  
239300
@@ -265,18 +326,95 @@ function MTK.HomotopyContinuationProblem(
265326    #  CSE/hashconsing.
266327    eqs =  full_equations (sys)
267328
268-     denoms =  []
269-     has_parametric_exponents =  false 
270-     eqs2 =  map (eqs) do  eq
329+     polydata =  map (eqs) do  eq
271330        data =  PolynomialData ()
272331        process_polynomial! (data, eq. lhs, dvs)
273332        process_polynomial! (data, eq. rhs, dvs)
274-         has_parametric_exponents |=  data. has_parametric_exponent
275-         if  ! isempty (data. non_polynomial_terms)
276-             throw (NotPolynomialError (eq, data))
333+         data
334+     end 
335+ 
336+     has_parametric_exponents =  any (d ->  d. has_parametric_exponent, polydata)
337+ 
338+     all_non_poly_terms =  mapreduce (d ->  d. non_polynomial_terms, vcat, polydata)
339+     unique! (all_non_poly_terms)
340+ 
341+     var_to_nonpoly =  Dict {BasicSymbolic, PolynomialTransformationData} ()
342+ 
343+     is_poly =  true 
344+     transformation_err =  nothing 
345+     for  t in  all_non_poly_terms
346+         #  if the term involves multiple unknowns, we can't invert it
347+         dvs_in_term =  map (x ->  occursin (x, t), dvs)
348+         if  count (dvs_in_term) >  1 
349+             transformation_err =  MultivarTerm (t, dvs[dvs_in_term])
350+             is_poly =  false 
351+             break 
352+         end 
353+         #  we already have a substitution solving for `var`
354+         var =  dvs[findfirst (dvs_in_term)]
355+         if  haskey (var_to_nonpoly, var) &&  ! isequal (var_to_nonpoly[var]. term, t)
356+             transformation_err =  MultipleTermsOfSameVar ([t, var_to_nonpoly[var]. term], var)
357+             is_poly =  false 
358+             break 
359+         end 
360+         #  we want to solve `term - new_var` for `var`
361+         new_var =  gensym (Symbol (var))
362+         new_var =  unwrap (only (@variables  $ new_var))
363+         invterm =  Symbolics. ia_solve (
364+             t -  new_var, var; complex_roots =  false , periodic_roots =  false , warns =  false )
365+         #  if we can't invert it, quit
366+         if  invterm ===  nothing  ||  isempty (invterm)
367+             transformation_err =  SymbolicSolveFailure (t, var)
368+             is_poly =  false 
369+             break 
370+         end 
371+         #  `ia_solve` returns lazy terms i.e. `asin(1.0)` instead of `pi/2`
372+         #  this just evaluates the constant expressions
373+         invterm =  Symbolics. substitute .(invterm, (Dict (),))
374+         #  RootsOf implies Symbolics couldn't solve the inner polynomial because
375+         #  `Nemo` wasn't loaded.
376+         if  any (x ->  MTK. iscall (x) &&  MTK. operation (x) ==  Symbolics. RootsOf, invterm)
377+             transformation_err =  NemoNotLoaded ()
378+             is_poly =  false 
379+             break 
277380        end 
278-         num, den =  handle_rational_polynomials (eq. rhs -  eq. lhs, dvs)
381+         var_to_nonpoly[var] =  PolynomialTransformationData (new_var, t, invterm)
382+     end 
383+ 
384+     if  ! is_poly
385+         throw (NotPolynomialError (transformation_err, eqs, polydata))
386+     end 
387+ 
388+     subrules =  Dict ()
389+     combinations =  Vector[]
390+     new_dvs =  []
391+     for  x in  dvs
392+         if  haskey (var_to_nonpoly, x)
393+             _data =  var_to_nonpoly[x]
394+             subrules[_data. term] =  _data. new_var
395+             push! (combinations, _data. inv_term)
396+             push! (new_dvs, _data. new_var)
397+         else 
398+             push! (combinations, [x])
399+             push! (new_dvs, x)
400+         end 
401+     end 
402+     all_solutions =  collect .(collect (Iterators. product (combinations... )))
279403
404+     denoms =  []
405+     eqs2 =  map (eqs) do  eq
406+         t =  eq. rhs -  eq. lhs
407+         t =  Symbolics. fixpoint_sub (t, subrules; maxiters =  length (dvs))
408+         #  the substituted variable occurs outside the substituted term
409+         poly_and_nonpoly =  map (dvs) do  x
410+             haskey (var_to_nonpoly, x) &&  occursin (x, t)
411+         end 
412+         if  any (poly_and_nonpoly)
413+             throw (NotPolynomialError (
414+                 VariablesAsPolyAndNonPoly (dvs[poly_and_nonpoly]), eqs, polydata))
415+         end 
416+ 
417+         num, den =  handle_rational_polynomials (t, new_dvs)
280418        #  make factors different elements, otherwise the nonzero factors artificially
281419        #  inflate the error of the zero factor.
282420        if  iscall (den) &&  operation (den) ==  * 
@@ -292,16 +430,19 @@ function MTK.HomotopyContinuationProblem(
292430    end 
293431
294432    sys2 =  MTK. @set  sys. eqs =  eqs2
433+     MTK. @set!  sys2. unknowns =  new_dvs
295434    #  remove observed equations to avoid adding them in codegen
296435    MTK. @set!  sys2. observed =  Equation[]
297436    MTK. @set!  sys2. substitutions =  nothing 
298437
299-     nlfn, u0, p =  MTK. process_SciMLProblem (NonlinearFunction{true }, sys2, u0map, parammap;
300-         jac =  true , eval_expression, eval_module)
438+     _, u0, p =  MTK. process_SciMLProblem (
439+         MTK. EmptySciMLFunction, sys, u0map, parammap; eval_expression, eval_module)
440+     nlfn =  NonlinearFunction {true} (sys2; jac =  true , eval_expression, eval_module)
301441
302-     denominator =  MTK. build_explicit_observed_function (sys, denoms)
442+     denominator =  MTK. build_explicit_observed_function (sys2, denoms)
443+     unpack_solution =  MTK. build_explicit_observed_function (sys2, all_solutions)
303444
304-     hvars =  symbolics_to_hc .(dvs )
445+     hvars =  symbolics_to_hc .(new_dvs )
305446    mtkhsys =  MTKHomotopySystem (nlfn. f, p, nlfn. jac, hvars, length (eqs))
306447
307448    obsfn =  MTK. ObservedFunctionCache (sys; eval_expression, eval_module)
@@ -319,7 +460,7 @@ function MTK.HomotopyContinuationProblem(
319460        solver_and_starts =  HomotopyContinuation. solver_startsolutions (mtkhsys; kwargs... )
320461    end 
321462    return  MTK. HomotopyContinuationProblem (
322-         u0, mtkhsys, denominator, sys, obsfn, solver_and_starts)
463+         u0, mtkhsys, denominator, sys, obsfn, solver_and_starts, unpack_solution )
323464end 
324465
325466""" 
@@ -353,25 +494,35 @@ function CommonSolve.solve(prob::MTK.HomotopyContinuationProblem,
353494    if  isempty (realsols)
354495        u =  state_values (prob)
355496        retcode =  SciMLBase. ReturnCode. ConvergenceFailure
497+         resid =  prob. homotopy_continuation_system (u)
356498    else 
357499        T =  eltype (state_values (prob))
358-         distance, idx =  findmin (realsols) do  result
500+         distance =  T (Inf )
501+         u =  state_values (prob)
502+         resid =  nothing 
503+         for  result in  realsols
359504            if  any (<= (denominator_abstol),
360505                prob. denominator (real .(result. solution), parameter_values (prob)))
361-                 return  T (Inf )
506+                 continue 
507+             end 
508+             for  truesol in  prob. unpack_solution (result. solution, parameter_values (prob))
509+                 dist =  norm (truesol -  state_values (prob))
510+                 if  dist <  distance
511+                     distance =  dist
512+                     u =  T .(real .(truesol))
513+                     resid =  T .(real .(prob. homotopy_continuation_system (result. solution)))
514+                 end 
362515            end 
363-             norm (result. solution -  state_values (prob))
364516        end 
365517        #  all roots cause denominator to be zero
366518        if  isinf (distance)
367519            u =  state_values (prob)
520+             resid =  prob. homotopy_continuation_system (u)
368521            retcode =  SciMLBase. ReturnCode. Infeasible
369522        else 
370-             u =  real .(realsols[idx]. solution)
371523            retcode =  SciMLBase. ReturnCode. Success
372524        end 
373525    end 
374-     resid =  prob. homotopy_continuation_system (u)
375526
376527    return  SciMLBase. build_solution (
377528        prob, :HomotopyContinuation , u, resid; retcode, original =  sol)
0 commit comments