Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion lib/NonlinearSolveHomotopyContinuation/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ DocStringExtensions = "0.9.3"
Enzyme = "0.13"
HomotopyContinuation = "2.12.0"
LinearAlgebra = "1.10"
NaNMath = "1.1"
NonlinearSolve = "4"
NonlinearSolveBase = "1.3.3"
SciMLBase = "2.72.2"
Expand All @@ -37,8 +38,9 @@ julia = "1.10"
[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "Test", "NonlinearSolve", "Enzyme"]
test = ["Aqua", "Test", "NonlinearSolve", "Enzyme", "NaNMath"]
15 changes: 13 additions & 2 deletions lib/NonlinearSolveHomotopyContinuation/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,11 @@ function CommonSolve.solve(prob::NonlinearProblem, alg::HomotopyContinuationJL{t
end
# unpack solutions and make them real
u = isscalar ? only(result.solution) : result.solution
append!(validsols, f.unpolynomialize(real.(u), p))
unpolysols = f.unpolynomialize(real.(u), p)
for sol in unpolysols
any(isnan, sol) && continue
push!(validsols, sol)
end
end

# if there are no valid solutions
Expand Down Expand Up @@ -137,10 +141,17 @@ function CommonSolve.solve(prob::NonlinearProblem, alg::HomotopyContinuationJL{f
T = eltype(u0)
validsols = f.unpolynomialize(realsol, p)
_, idx = findmin(validsols) do sol
norm(sol - u0_p)
any(isnan, sol) ? Inf : norm(sol - u0_p)
end

u = map(real, validsols[idx])

if any(isnan, u)
retcode = SciMLBase.ReturnCode.Infeasible
resid = NonlinearSolveBase.Utils.evaluate_f(prob, u0)
return SciMLBase.build_solution(prob, alg, u0, resid; retcode, original = orig_sol)
end

if u0 isa Number
u = only(u)
end
Expand Down
20 changes: 20 additions & 0 deletions lib/NonlinearSolveHomotopyContinuation/test/allroots.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using NonlinearSolveHomotopyContinuation
using SciMLBase: NonlinearSolution
using ADTypes
using Enzyme
import NaNMath

alg = HomotopyContinuationJL{true}(; threading = false)

Expand Down Expand Up @@ -170,3 +171,22 @@ end
end
end
end

@testset "`NaN` unpolynomialize" begin
polynomialize = function (u, p)
return sin(u^2)
end
unpolynomialize = function (u, p)
return (-NaNMath.sqrt(NaNMath.asin(u)), NaNMath.sqrt(NaNMath.asin(u)))
end
rhs = function (u, p)
return u^2 + u - 1
end
prob = NonlinearProblem(
HomotopyNonlinearFunction(rhs; polynomialize, unpolynomialize), 1.0)
sol = solve(prob, alg)
@test sol isa EnsembleSolution
for nlsol in sol.u
@test !isnan(nlsol.u)
end
end
17 changes: 17 additions & 0 deletions lib/NonlinearSolveHomotopyContinuation/test/single_root.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using NonlinearSolve
using NonlinearSolveHomotopyContinuation
using SciMLBase: NonlinearSolution
import NaNMath

alg = HomotopyContinuationJL{false}(; threading = false)

Expand Down Expand Up @@ -146,3 +147,19 @@ end
end
end
end

@testset "`NaN` unpolynomialize" begin
polynomialize = function (u, p)
return sin(u^2)
end
unpolynomialize = function (u, p)
return NaN
end
rhs = function (u, p)
return u^2 + u - 1
end
prob = NonlinearProblem(
HomotopyNonlinearFunction(rhs; polynomialize, unpolynomialize), 1.0)
sol = solve(prob, alg)
@test !SciMLBase.successful_retcode(sol)
end
Loading