Skip to content

Commit ee15d80

Browse files
committed
Fix most tests
1 parent cefe5b0 commit ee15d80

File tree

6 files changed

+31
-15
lines changed

6 files changed

+31
-15
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ LeastSquaresOptim = "0.8"
6060
LineSearches = "7"
6161
LinearAlgebra = "<0.0.1, 1"
6262
LinearSolve = "2.12"
63+
MaybeInplace = "0.1"
6364
NaNMath = "1"
6465
NonlinearProblemLibrary = "0.1"
6566
Pkg = "1"
@@ -71,7 +72,7 @@ Reexport = "0.2, 1"
7172
SafeTestsets = "0.1"
7273
SciMLBase = "2.9"
7374
SciMLOperators = "0.3"
74-
SimpleNonlinearSolve = "1" # FIXME: Don't update the version in this PR. Using it to test
75+
SimpleNonlinearSolve = "1"
7576
SparseArrays = "<0.0.1, 1"
7677
SparseDiffTools = "2.14"
7778
StaticArrays = "1"
@@ -98,6 +99,7 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
9899
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
99100
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
100101
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
102+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
101103
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
102104
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
103105
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

src/NonlinearSolve.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_work
1717
import ConcreteStructs: @concrete
1818
import EnumX: @enumx
1919
import FastBroadcast: @..
20+
import FiniteDiff
2021
import ForwardDiff
2122
import ForwardDiff: Dual
2223
import LinearSolve: ComposePreconditioner, InvPreconditioner, needs_concrete_A
@@ -56,7 +57,7 @@ function SciMLBase.reinit!(cache::AbstractNonlinearSolveCache{iip}, u0 = get_u(c
5657
cache.p = p
5758
if iip
5859
recursivecopy!(get_u(cache), u0)
59-
cache.f(cache.fu1, get_u(cache), p)
60+
cache.f(get_fu(cache), get_u(cache), p)
6061
else
6162
cache.u = __maybe_unaliased(u0, alias_u0)
6263
set_fu!(cache, cache.f(cache.u, p))
@@ -76,7 +77,7 @@ function SciMLBase.reinit!(cache::AbstractNonlinearSolveCache{iip}, u0 = get_u(c
7677

7778
if hasfield(typeof(cache), :ls_cache)
7879
# TODO: A more efficient way to do this
79-
cache.ls_cache = init_linesearch_cache(cache.prob, cache.alg.linesearch, cache.f,
80+
cache.ls_cache = init_linesearch_cache(cache.alg.linesearch, cache.f,
8081
get_u(cache), p, get_fu(cache), Val(iip))
8182
end
8283

src/dfsane.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ Computation, 75, 1429-1448.](https://www.researchgate.net/publication/220576479_
5555
end
5656

5757
@concrete mutable struct DFSaneCache{iip} <: AbstractNonlinearSolveCache{iip}
58+
f
5859
alg
5960
u
6061
u_cache
@@ -110,8 +111,8 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane, args.
110111
termination_condition)
111112
trace = init_nonlinearsolve_trace(alg, u, fu, nothing, du; kwargs...)
112113

113-
return DFSaneCache{iip}(alg, u, u_cache, u_cache_2, fu, fu_cache, du, history, f_norm,
114-
f_norm_0, alg.M, T(alg.σ_1), T(alg.σ_min), T(alg.σ_max), one(T), T(alg.γ),
114+
return DFSaneCache{iip}(prob.f, alg, u, u_cache, u_cache_2, fu, fu_cache, du, history,
115+
f_norm, f_norm_0, alg.M, T(alg.σ_1), T(alg.σ_min), T(alg.σ_max), one(T), T(alg.γ),
115116
T(alg.τ_min), T(alg.τ_max), alg.n_exp, prob.p, false, maxiters, internalnorm,
116117
ReturnCode.Default, abstol, reltol, prob, NLStats(1, 0, 0, 0, 0), tc_cache, trace)
117118
end

src/jacobian.jl

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -209,10 +209,14 @@ function __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf)
209209
end
210210

211211
# jvp fallback scalar
212-
__jacvec(args...; kwargs...) = JacVec(args...; kwargs...)
213-
function __jacvec(uf, u::Number; autodiff, kwargs...)
214-
@assert autodiff isa AutoForwardDiff "Only ForwardDiff is currently supported."
215-
return JVPScalar(uf, u, autodiff)
212+
function __jacvec(uf, u; autodiff, kwargs...)
213+
if !(autodiff isa AutoForwardDiff || autodiff isa AutoFiniteDiff)
214+
_ad = autodiff
215+
autodiff = ifelse(ForwardDiff.can_dual(eltype(u)), AutoForwardDiff(),
216+
AutoFiniteDiff())
217+
@warn "$(_ad) not supported for JacVec. Using $(autodiff) instead."
218+
end
219+
return u isa Number ? JVPScalar(uf, u, autodiff) : JacVec(uf, u; autodiff, kwargs...)
216220
end
217221

218222
@concrete mutable struct JVPScalar
@@ -221,10 +225,17 @@ end
221225
autodiff
222226
end
223227

224-
function Base.:*(jvp::JVPScalar, v)
225-
T = typeof(ForwardDiff.Tag(typeof(jvp.uf), typeof(jvp.u)))
226-
out = jvp.uf(ForwardDiff.Dual{T}(jvp.u, v))
227-
return ForwardDiff.extract_derivative(T, out)
228+
function Base.:*(jvp::JVPScalar, v::Number)
229+
if jvp.autodiff isa AutoForwardDiff
230+
T = typeof(ForwardDiff.Tag(typeof(jvp.uf), typeof(jvp.u)))
231+
out = jvp.uf(ForwardDiff.Dual{T}(jvp.u, v))
232+
return ForwardDiff.extract_derivative(T, out)
233+
elseif jvp.autodiff isa AutoFiniteDiff
234+
J = FiniteDiff.finite_difference_derivative(jvp.uf, jvp.u, jvp.autodiff.fdtype)
235+
return J * v
236+
else
237+
error("Only ForwardDiff & FiniteDiff is currently supported.")
238+
end
228239
end
229240

230241
# Generic Handling of Krylov Methods for Normal Form Linear Solves

src/trustRegion.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,8 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion,
255255
@bb u_cache_2 = similar(u)
256256
@bb u_cauchy = similar(u)
257257
@bb u_gauss_newton = similar(u)
258-
@bb J_cache = similar(J)
258+
J_cache = J isa SciMLOperators.AbstractSciMLOperator ||
259+
setindex_trait(J) === CannotSetindex() ? J : similar(J)
259260
@bb lr_mul_cache = similar(du)
260261

261262
loss_new = loss

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ function evaluate_f(prob::Union{NonlinearProblem{uType, iip},
178178
return fu
179179
end
180180

181-
function evaluate_f(f::F, u, p, ::Val{iip}; fu = nothing) where {F, iip <: Bool}
181+
function evaluate_f(f::F, u, p, ::Val{iip}; fu = nothing) where {F, iip}
182182
if iip
183183
f(fu, u, p)
184184
return fu

0 commit comments

Comments
 (0)