diff --git a/lib/NonlinearSolveHomotopyContinuation/src/NonlinearSolveHomotopyContinuation.jl b/lib/NonlinearSolveHomotopyContinuation/src/NonlinearSolveHomotopyContinuation.jl index 26639e4f0..e679727be 100644 --- a/lib/NonlinearSolveHomotopyContinuation/src/NonlinearSolveHomotopyContinuation.jl +++ b/lib/NonlinearSolveHomotopyContinuation/src/NonlinearSolveHomotopyContinuation.jl @@ -17,13 +17,15 @@ using ConcreteStructs: @concrete export HomotopyContinuationJL, HomotopyNonlinearFunction """ - HomotopyContinuationJL{AllRoots}(; autodiff = true, kwargs...) + HomotopyContinuationJL{AllRoots, ComplexRoots}(; autodiff = true, kwargs...) HomotopyContinuationJL(; kwargs...) = HomotopyContinuationJL{false}(; kwargs...) This algorithm is an interface to `HomotopyContinuation.jl`. It is only valid for fully determined polynomial systems. The `AllRoots` type parameter can be `true` or `false` and controls whether the solver will find all roots of the polynomial or a single root close to the initial guess provided to the `NonlinearProblem`. +The `ComplexRoots` type parameter can be `Val{true}` or `Val{false}` (default) and +controls whether complex roots are returned or filtered to only real roots. The polynomial function must allow complex numbers to be provided as the state. If `AllRoots` is `true`, the initial guess in the `NonlinearProblem` is ignored. @@ -36,6 +38,9 @@ depends on the initial guess provided to the `NonlinearProblem` being solved. Th does not require that the polynomial function is traceable via HomotopyContinuation.jl's symbolic variables. +If `ComplexRoots` is `Val{true}`, complex roots will be returned. If `Val{false}` (default), +only real roots will be returned. + HomotopyContinuation.jl requires the jacobian of the system. In case a jacobian function is provided, it will be used. Otherwise, the `autodiff` keyword argument controls the autodiff method used to compute the jacobian. A value of `true` refers to @@ -45,23 +50,27 @@ specified using ADTypes.jl. HomotopyContinuation.jl requires the taylor series of the polynomial system for the single root method. This is automatically computed using TaylorSeries.jl. """ -@concrete struct HomotopyContinuationJL{AllRoots} <: +@concrete struct HomotopyContinuationJL{AllRoots, ComplexRoots} <: NonlinearSolveBase.AbstractNonlinearSolveAlgorithm autodiff kwargs end -function HomotopyContinuationJL{AllRoots}(; autodiff = true, kwargs...) where {AllRoots} +function HomotopyContinuationJL{AllRoots, ComplexRoots}(; autodiff = true, kwargs...) where {AllRoots, ComplexRoots} if autodiff isa Bool autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff() end - HomotopyContinuationJL{AllRoots}(autodiff, kwargs) + HomotopyContinuationJL{AllRoots, ComplexRoots}(autodiff, kwargs) +end + +function HomotopyContinuationJL{AllRoots}(; autodiff = true, kwargs...) where {AllRoots} + HomotopyContinuationJL{AllRoots, Val{false}}(; autodiff, kwargs...) end HomotopyContinuationJL(; kwargs...) = HomotopyContinuationJL{false}(; kwargs...) -function HomotopyContinuationJL(alg::HomotopyContinuationJL{R}; kwargs...) where {R} - HomotopyContinuationJL{R}(; autodiff = alg.autodiff, alg.kwargs..., kwargs...) +function HomotopyContinuationJL(alg::HomotopyContinuationJL{R, C}; kwargs...) where {R, C} + HomotopyContinuationJL{R, C}(; autodiff = alg.autodiff, alg.kwargs..., kwargs...) end include("interface_types.jl") diff --git a/lib/NonlinearSolveHomotopyContinuation/src/solve.jl b/lib/NonlinearSolveHomotopyContinuation/src/solve.jl index dee0bd921..f3766ae31 100644 --- a/lib/NonlinearSolveHomotopyContinuation/src/solve.jl +++ b/lib/NonlinearSolveHomotopyContinuation/src/solve.jl @@ -54,8 +54,8 @@ function homotopy_continuation_preprocessing( return f, hcsys end -function CommonSolve.solve(prob::NonlinearProblem, alg::HomotopyContinuationJL{true}; - denominator_abstol = 1e-7, kwargs...) +function CommonSolve.solve(prob::NonlinearProblem, alg::HomotopyContinuationJL{true, ComplexRoots}; + denominator_abstol = 1e-7, kwargs...) where {ComplexRoots} f, hcsys = homotopy_continuation_preprocessing(prob, alg) u0 = state_values(prob) @@ -63,28 +63,30 @@ function CommonSolve.solve(prob::NonlinearProblem, alg::HomotopyContinuationJL{t isscalar = u0 isa Number orig_sol = HC.solve(hcsys; alg.kwargs..., kwargs...) - realsols = HC.results(orig_sol; only_real = true) - # no real solutions + only_real_roots = ComplexRoots === Val{false} + realsols = HC.results(orig_sol; only_real = only_real_roots) + # no solutions if isempty(realsols) retcode = SciMLBase.ReturnCode.ConvergenceFailure resid = NonlinearSolveBase.Utils.evaluate_f(prob, u0) nlsol = SciMLBase.build_solution(prob, alg, u0, resid; retcode, original = orig_sol) return SciMLBase.EnsembleSolution([nlsol], 0.0, false, nothing) end - T = eltype(u0) + T = ComplexRoots === Val{false} ? eltype(u0) : promote_type(eltype(u0), Complex{real(eltype(u0))}) validsols = isscalar ? T[] : Vector{T}[] for result in realsols # ignore ones which make the denominator zero - real_u = real.(result.solution) + test_u = ComplexRoots === Val{false} ? real.(result.solution) : result.solution if isscalar - real_u = only(real_u) + test_u = only(test_u) end - if any(<=(denominator_abstol) ∘ abs, f.denominator(real_u, p)) + if any(<=(denominator_abstol) ∘ abs, f.denominator(test_u, p)) continue end - # unpack solutions and make them real + # unpack solutions u = isscalar ? only(result.solution) : result.solution - unpolysols = f.unpolynomialize(real.(u), p) + u_for_unpolynom = ComplexRoots === Val{false} ? real.(u) : u + unpolysols = f.unpolynomialize(u_for_unpolynom, p) for sol in unpolysols any(isnan, sol) && continue push!(validsols, sol) @@ -107,8 +109,8 @@ function CommonSolve.solve(prob::NonlinearProblem, alg::HomotopyContinuationJL{t return SciMLBase.EnsembleSolution(nlsols, 0.0, true, nothing) end -function CommonSolve.solve(prob::NonlinearProblem, alg::HomotopyContinuationJL{false}; - denominator_abstol = 1e-7, kwargs...) +function CommonSolve.solve(prob::NonlinearProblem, alg::HomotopyContinuationJL{false, ComplexRoots}; + denominator_abstol = 1e-7, kwargs...) where {ComplexRoots} f, hcsys = homotopy_continuation_preprocessing(prob, alg) u0 = state_values(prob) @@ -120,14 +122,15 @@ function CommonSolve.solve(prob::NonlinearProblem, alg::HomotopyContinuationJL{f homotopy = GuessHomotopy(hcsys, fu0) orig_sol = HC.solve( homotopy, u0_p isa Number ? [[u0_p]] : [u0_p]; alg.kwargs..., kwargs...) - realsols = map(res -> res.solution, HC.results(orig_sol; only_real = true)) + only_real_roots = ComplexRoots === Val{false} + realsols = map(res -> res.solution, HC.results(orig_sol; only_real = only_real_roots)) if u0 isa Number realsols = map(only, realsols) end - # no real solutions or infeasible solution + # no solutions or infeasible solution if isempty(realsols) || - any(<=(denominator_abstol), map(abs, f.denominator(real.(only(realsols)), p))) + any(<=(denominator_abstol), map(abs, f.denominator(ComplexRoots === Val{false} ? real.(only(realsols)) : only(realsols), p))) retcode = if isempty(realsols) SciMLBase.ReturnCode.ConvergenceFailure else @@ -137,14 +140,14 @@ function CommonSolve.solve(prob::NonlinearProblem, alg::HomotopyContinuationJL{f return SciMLBase.build_solution(prob, alg, u0, resid; retcode, original = orig_sol) end - realsol = real(only(realsols)) + realsol = ComplexRoots === Val{false} ? real(only(realsols)) : only(realsols) T = eltype(u0) validsols = f.unpolynomialize(realsol, p) _, idx = findmin(validsols) do sol any(isnan, sol) ? Inf : norm(sol - u0_p) end - u = map(real, validsols[idx]) + u = ComplexRoots === Val{false} ? map(real, validsols[idx]) : validsols[idx] if any(isnan, u) retcode = SciMLBase.ReturnCode.Infeasible diff --git a/lib/NonlinearSolveHomotopyContinuation/test/complex_roots_test.jl b/lib/NonlinearSolveHomotopyContinuation/test/complex_roots_test.jl new file mode 100644 index 000000000..b43d734b2 --- /dev/null +++ b/lib/NonlinearSolveHomotopyContinuation/test/complex_roots_test.jl @@ -0,0 +1,94 @@ +using NonlinearSolve +using NonlinearSolveHomotopyContinuation +using SciMLBase: NonlinearSolution + +# Test complex roots for scalar polynomial +@testset "Complex roots - scalar" begin + # Polynomial: u^2 + 1 = 0, roots should be ±i + rhs = function (u, p) + return u * u + 1 + end + + prob = NonlinearProblem(rhs, 1.0 + 0.0im) + + # Test with complex roots enabled + alg_complex = HomotopyContinuationJL{true, Val{true}}(; threading = false) + sol_complex = solve(prob, alg_complex) + + @test sol_complex isa EnsembleSolution + @test sol_complex.converged + @test length(sol_complex) == 2 + + # Sort solutions by imaginary part + solutions = [s.u for s in sol_complex.u] + sort!(solutions; by = imag) + + @test solutions[1] ≈ -1im atol=1e-10 + @test solutions[2] ≈ 1im atol=1e-10 + + # Test with complex roots disabled (should find no real solutions) + alg_real = HomotopyContinuationJL{true, Val{false}}(; threading = false) + sol_real = solve(prob, alg_real) + + @test !sol_real.converged + @test length(sol_real) == 1 + @test sol_real.u[1].retcode == SciMLBase.ReturnCode.ConvergenceFailure +end + +# Test complex roots for vector polynomial +@testset "Complex roots - vector" begin + # System: u[1]^2 + 1 = 0, u[2]^2 + 4 = 0 + # Roots should be [±i, ±2i] + rhs = function (u, p) + return [u[1]^2 + 1, u[2]^2 + 4] + end + + prob = NonlinearProblem(rhs, [1.0 + 0.0im, 1.0 + 0.0im]) + + # Test with complex roots enabled + alg_complex = HomotopyContinuationJL{true, Val{true}}(; threading = false) + sol_complex = solve(prob, alg_complex) + + @test sol_complex isa EnsembleSolution + @test sol_complex.converged + @test length(sol_complex) == 4 + + # Verify all solutions are approximately correct + for s in sol_complex.u + u = s.u + @test abs(u[1]^2 + 1) < 1e-10 + @test abs(u[2]^2 + 4) < 1e-10 + end + + # Test with complex roots disabled (should find no real solutions) + alg_real = HomotopyContinuationJL{true, Val{false}}(; threading = false) + sol_real = solve(prob, alg_real) + + @test !sol_real.converged + @test length(sol_real) == 1 + @test sol_real.u[1].retcode == SciMLBase.ReturnCode.ConvergenceFailure +end + +# Test single root method with complex roots +@testset "Complex roots - single root" begin + # Polynomial: u^2 + 1 = 0 + rhs = function (u, p) + return u * u + 1 + end + + prob = NonlinearProblem(rhs, 1.0 + 0.0im) + + # Test with complex roots enabled + alg_complex = HomotopyContinuationJL{false, Val{true}}(; threading = false) + sol_complex = solve(prob, alg_complex) + + @test sol_complex isa NonlinearSolution + @test SciMLBase.successful_retcode(sol_complex) + @test abs(sol_complex.u^2 + 1) < 1e-10 + + # Test with complex roots disabled + alg_real = HomotopyContinuationJL{false, Val{false}}(; threading = false) + sol_real = solve(prob, alg_real) + + @test !SciMLBase.successful_retcode(sol_real) +end \ No newline at end of file diff --git a/lib/NonlinearSolveHomotopyContinuation/test/runtests.jl b/lib/NonlinearSolveHomotopyContinuation/test/runtests.jl index 7a4a89ed8..16726d41b 100644 --- a/lib/NonlinearSolveHomotopyContinuation/test/runtests.jl +++ b/lib/NonlinearSolveHomotopyContinuation/test/runtests.jl @@ -14,4 +14,7 @@ using Aqua @testset "Single Root" begin include("single_root.jl") end + @testset "Complex Roots" begin + include("complex_roots_test.jl") + end end