Skip to content

Commit 8cf9899

Browse files
committed
fix: auto-set autodiff for ForwardDiff if trying to propagate Duals
1 parent 244d7bb commit 8cf9899

File tree

4 files changed

+10
-5
lines changed

4 files changed

+10
-5
lines changed

lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ function CommonSolve.solve(
5454
alg::AbstractSimpleNonlinearSolveAlgorithm,
5555
args...;
5656
kwargs...) where {T, V, P, iip}
57+
if hasfield(typeof(alg), :autodiff) && alg.autodiff === nothing
58+
@reset alg.autodiff = AutoForwardDiff()
59+
end
5760
prob = convert(ImmutableNonlinearProblem, prob)
5861
sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...)
5962
dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p)
@@ -68,6 +71,9 @@ function CommonSolve.solve(
6871
alg::AbstractSimpleNonlinearSolveAlgorithm,
6972
args...;
7073
kwargs...) where {T, V, P, iip}
74+
if hasfield(typeof(alg), :autodiff) && alg.autodiff === nothing
75+
@reset alg.autodiff = AutoForwardDiff()
76+
end
7177
sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...)
7278
dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p)
7379
return SciMLBase.build_solution(

lib/SimpleNonlinearSolve/test/core/forward_diff_tests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
jacobian_f(u, p::Number) = one.(u) .* (1 / (2 * p))
1515
jacobian_f(u, p::AbstractArray) = diagm(vec(@. 1 / (2 * p)))
1616

17-
@testset for alg in (
17+
@testset "#(nameof(typeof(alg)))" for alg in (
1818
SimpleNewtonRaphson(),
1919
SimpleTrustRegion(),
2020
SimpleTrustRegion(; nlsolve_update_rule = Val(true)),
@@ -118,7 +118,7 @@ end
118118

119119
θ_init = θ_true .+ 0.1
120120

121-
@testset for alg in (
121+
for alg in (
122122
SimpleGaussNewton(),
123123
SimpleGaussNewton(; autodiff = AutoForwardDiff()),
124124
SimpleGaussNewton(; autodiff = AutoFiniteDiff()),

lib/SimpleNonlinearSolve/test/core/qa_tests.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@ end
1212
import ReverseDiff, Tracker, StaticArrays, Zygote
1313
using ExplicitImports, SimpleNonlinearSolve
1414

15-
@test check_no_implicit_imports(
16-
SimpleNonlinearSolve; skip = (Base, Core, SciMLBase)) === nothing
15+
@test check_no_implicit_imports(SimpleNonlinearSolve; skip = (Base, Core)) === nothing
1716
@test check_no_stale_explicit_imports(SimpleNonlinearSolve) === nothing
1817
@test check_all_qualified_accesses_via_owners(SimpleNonlinearSolve) === nothing
1918
end

lib/SimpleNonlinearSolve/test/core/rootfind_tests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
end
3737

3838
@testitem "First Order Methods" setup=[RootfindTestSnippet] tags=[:core] begin
39-
@testset for alg in (
39+
for alg in (
4040
SimpleNewtonRaphson,
4141
SimpleTrustRegion,
4242
(; kwargs...) -> SimpleTrustRegion(; kwargs..., nlsolve_update_rule = Val(true))

0 commit comments

Comments
 (0)