Skip to content

Commit 4c28213

Browse files
committed
fix: exotic types
1 parent 4d9c30e commit 4c28213

11 files changed

+71
-21
lines changed

lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveChainRulesCoreExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ using SimpleNonlinearSolve: SimpleNonlinearSolve, simplenonlinearsolve_solve_up,
88
solve_adjoint
99

1010
function ChainRulesCore.rrule(::typeof(simplenonlinearsolve_solve_up),
11-
prob::Union{InternalNonlinearProblem, NonlinearLeastSquaresProblem},
11+
prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem},
1212
sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...)
1313
out, ∇internal = solve_adjoint(
1414
prob, sensealg, u0, p, ChainRulesOriginator(), alg, args...; kwargs...)

lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@ using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal
66
using SciMLBase: ReverseDiffOriginator, NonlinearLeastSquaresProblem, remake
77

88
using SimpleNonlinearSolve: SimpleNonlinearSolve, solve_adjoint
9+
import SimpleNonlinearSolve: simplenonlinearsolve_solve_up
910

10-
for pType in (InternalNonlinearProblem, NonlinearLeastSquaresProblem)
11+
for pType in (ImmutableNonlinearProblem, NonlinearLeastSquaresProblem)
1112
aTypes = (TrackedArray, AbstractArray{<:TrackedReal}, Any)
1213
for (uT, pT) in collect(Iterators.product(aTypes, aTypes))[1:(end - 1)]
13-
@eval function SimpleNonlinearSolve.simplenonlinearsolve_solve_up(
14+
@eval function simplenonlinearsolve_solve_up(
1415
prob::$(pType), sensealg, u0::$(uT), u0_changed,
1516
p::$(pT), p_changed, alg, args...; kwargs...)
1617
return ReverseDiff.track(SimpleNonlinearSolve.simplenonlinearsolve_solve_up,
@@ -19,7 +20,7 @@ for pType in (InternalNonlinearProblem, NonlinearLeastSquaresProblem)
1920
end
2021
end
2122

22-
@eval ReverseDiff.@grad function SimpleNonlinearSolve.simplenonlinearsolve_solve_up(
23+
@eval ReverseDiff.@grad function simplenonlinearsolve_solve_up(
2324
tprob::$(pType), sensealg, tu0, u0_changed,
2425
tp, p_changed, alg, args...; kwargs...)
2526
u0, p = ReverseDiff.value(tu0), ReverseDiff.value(tp)

lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using Tracker: Tracker, TrackedArray, TrackedReal
77

88
using SimpleNonlinearSolve: SimpleNonlinearSolve, solve_adjoint
99

10-
for pType in (InternalNonlinearProblem, NonlinearLeastSquaresProblem)
10+
for pType in (ImmutableNonlinearProblem, NonlinearLeastSquaresProblem)
1111
aTypes = (TrackedArray, AbstractArray{<:TrackedReal}, Any)
1212
for (uT, pT) in collect(Iterators.product(aTypes, aTypes))[1:(end - 1)]
1313
@eval function SimpleNonlinearSolve.simplenonlinearsolve_solve_up(

lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,20 @@ using SciMLBase: AbstractNonlinearAlgorithm, NonlinearProblem, ReturnCode
1414
using StaticArraysCore: StaticArray, SArray, SVector, MArray
1515

1616
# AD Dependencies
17-
using ADTypes: AbstractADType, AutoFiniteDiff, AutoForwardDiff, AutoPolyesterForwardDiff
17+
using ADTypes: AutoFiniteDiff, AutoForwardDiff, AutoPolyesterForwardDiff
1818
using DifferentiationInterface: DifferentiationInterface
1919
using FiniteDiff: FiniteDiff
2020
using ForwardDiff: ForwardDiff
2121

2222
using BracketingNonlinearSolve: Alefeld, Bisection, Brent, Falsi, ITP, Ridder
23-
using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, get_tolerance,
24-
L2_NORM
23+
using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, L2_NORM
2524

2625
const DI = DifferentiationInterface
2726

2827
abstract type AbstractSimpleNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end
2928

29+
const safe_similar = NonlinearSolveBase.Utils.safe_similar
30+
3031
is_extension_loaded(::Val) = false
3132

3233
include("utils.jl")

lib/SimpleNonlinearSolve/src/halley.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ function SciMLBase.__solve(
4141

4242
strait = setindex_trait(x)
4343

44-
A = strait isa CanSetindex ? similar(x, length(x), length(x)) : x
45-
Aaᵢ = strait isa CanSetindex ? similar(x, length(x)) : x
46-
cᵢ = strait isa CanSetindex ? similar(x) : x
44+
A = strait isa CanSetindex ? safe_similar(x, length(x), length(x)) : x
45+
Aaᵢ = strait isa CanSetindex ? safe_similar(x, length(x)) : x
46+
cᵢ = strait isa CanSetindex ? safe_similar(x) : x
4747

4848
for _ in 1:maxiters
4949
fx, J, H = Utils.compute_jacobian_and_hessian(autodiff, prob, fx, x)

lib/SimpleNonlinearSolve/src/lbroyden.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ end
301301
return :(return SVector{$N, $T}(($(getcalls...))))
302302
end
303303

304-
lbroyden_threshold_cache(x, ::Val{threshold}) where {threshold} = similar(x, threshold)
304+
lbroyden_threshold_cache(x, ::Val{threshold}) where {threshold} = safe_similar(x, threshold)
305305
function lbroyden_threshold_cache(x::StaticArray, ::Val{threshold}) where {threshold}
306306
return zeros(MArray{Tuple{threshold}, eltype(x)})
307307
end
@@ -327,7 +327,7 @@ end
327327
end
328328
end
329329
function init_low_rank_jacobian(u, fu, ::Val{threshold}) where {threshold}
330-
Vᵀ = similar(u, threshold, length(u))
331-
U = similar(u, length(fu), threshold)
330+
Vᵀ = safe_similar(u, threshold, length(u))
331+
U = safe_similar(u, length(fu), threshold)
332332
return U, Vᵀ
333333
end

lib/SimpleNonlinearSolve/src/raphson.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ function SciMLBase.__solve(
4141
NonlinearSolveBase.select_jacobian_autodiff(prob, alg.autodiff)
4242

4343
@bb xo = similar(x)
44-
fx_cache = (SciMLBase.isinplace(prob) && !SciMLBase.has_jac(prob.f)) ? similar(fx) :
45-
nothing
44+
fx_cache = (SciMLBase.isinplace(prob) && !SciMLBase.has_jac(prob.f)) ?
45+
safe_similar(fx) : nothing
4646
jac_cache = Utils.prepare_jacobian(prob, autodiff, fx_cache, x)
4747
J = Utils.compute_jacobian!!(nothing, prob, autodiff, fx_cache, x, jac_cache)
4848

lib/SimpleNonlinearSolve/src/trust_region.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,8 @@ function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleTrustRegi
9393
norm_fx = L2_NORM(fx)
9494

9595
@bb xo = copy(x)
96-
fx_cache = (SciMLBase.isinplace(prob) && !SciMLBase.has_jac(prob.f)) ? similar(fx) :
97-
nothing
96+
fx_cache = (SciMLBase.isinplace(prob) && !SciMLBase.has_jac(prob.f)) ?
97+
safe_similar(fx) : nothing
9898
jac_cache = Utils.prepare_jacobian(prob, autodiff, fx_cache, x)
9999
J = Utils.compute_jacobian!!(nothing, prob, autodiff, fx_cache, x, jac_cache)
100100

lib/SimpleNonlinearSolve/src/utils.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
module Utils
22

3-
using ADTypes: AbstractADType, AutoForwardDiff, AutoFiniteDiff, AutoPolyesterForwardDiff
43
using ArrayInterface: ArrayInterface
54
using ConcreteStructs: @concrete
65
using DifferentiationInterface: DifferentiationInterface, Constant
@@ -164,7 +163,7 @@ function compute_jacobian!!(J, prob, autodiff, fx, x, extras)
164163
if J === nothing
165164
if extras isa AnalyticJacobian
166165
if SciMLBase.isinplace(prob.f)
167-
J = similar(fx, length(fx), length(x))
166+
J = safe_similar(fx, length(fx), length(x))
168167
prob.f.jac(J, x, prob.p)
169168
return J
170169
else
@@ -219,7 +218,7 @@ end
219218
function compute_jacobian_and_hessian(autodiff, prob, fx, x)
220219
if SciMLBase.isinplace(prob)
221220
jac_fn = @closure (u, p) -> begin
222-
du = similar(fx, promote_type(eltype(fx), eltype(u)))
221+
du = safe_similar(fx, promote_type(eltype(fx), eltype(u)))
223222
return DI.jacobian(prob.f, du, autodiff, u, Constant(p))
224223
end
225224
J, H = DI.value_and_jacobian(jac_fn, autodiff, x, Constant(prob.p))
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
@testitem "BigFloat Support" tags=[:core] begin
2+
using SimpleNonlinearSolve, LinearAlgebra
3+
4+
fn_iip = NonlinearFunction{true}((du, u, p) -> du .= u .* u .- p)
5+
fn_oop = NonlinearFunction{false}((u, p) -> u .* u .- p)
6+
7+
u0 = BigFloat[1.0, 1.0, 1.0]
8+
prob_iip_bf = NonlinearProblem{true}(fn_iip, u0, BigFloat(2))
9+
prob_oop_bf = NonlinearProblem{false}(fn_oop, u0, BigFloat(2))
10+
11+
@testset "$(nameof(typeof(alg)))" for alg in (
12+
SimpleNewtonRaphson(),
13+
SimpleBroyden(),
14+
SimpleKlement(),
15+
SimpleDFSane(),
16+
SimpleTrustRegion(),
17+
SimpleLimitedMemoryBroyden(),
18+
SimpleHalley()
19+
)
20+
sol = solve(prob_oop_bf, alg)
21+
@test maximum(abs, sol.resid) < 1e-6
22+
@test SciMLBase.successful_retcode(sol.retcode)
23+
24+
alg isa SimpleHalley && continue
25+
26+
sol = solve(prob_iip_bf, alg)
27+
@test maximum(abs, sol.resid) < 1e-6
28+
@test SciMLBase.successful_retcode(sol.retcode)
29+
end
30+
end

0 commit comments

Comments
 (0)