Skip to content

Commit 685319b

Browse files
committed
refactor: move JacobianCache into NonlinearSolveBase
1 parent 5549c42 commit 685319b

23 files changed

+374
-317
lines changed

.github/workflows/CI_NonlinearSolveBase.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,12 @@ jobs:
4949
run: |
5050
import Pkg
5151
Pkg.Registry.update()
52+
# Install packages present in subdirectories
53+
dev_pks = Pkg.PackageSpec[]
54+
for path in ("lib/SciMLJacobianOperators",)
55+
push!(dev_pks, Pkg.PackageSpec(; path))
56+
end
57+
Pkg.develop(dev_pks)
5258
Pkg.instantiate()
5359
Pkg.test(; coverage=true)
5460
shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/NonlinearSolveBase {0}

lib/NonlinearSolveBase/Project.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
1919
MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb"
2020
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
2121
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
22+
SciMLJacobianOperators = "19f34311-ddf3-4b8b-af20-060888a46c0e"
2223
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
2324
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
2425

@@ -27,12 +28,14 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
2728
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2829
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
2930
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
31+
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
3032

3133
[extensions]
3234
NonlinearSolveBaseDiffEqBaseExt = "DiffEqBase"
3335
NonlinearSolveBaseForwardDiffExt = "ForwardDiff"
3436
NonlinearSolveBaseLinearSolveExt = "LinearSolve"
3537
NonlinearSolveBaseSparseArraysExt = "SparseArrays"
38+
NonlinearSolveBaseSparseMatrixColoringsExt = "SparseMatrixColorings"
3639

3740
[compat]
3841
ADTypes = "1.9"
@@ -56,8 +59,10 @@ Markdown = "1.10"
5659
MaybeInplace = "0.1.4"
5760
RecursiveArrayTools = "3"
5861
SciMLBase = "2.50"
62+
SciMLJacobianOperators = "0.1.1"
5963
SciMLOperators = "0.3.10"
6064
SparseArrays = "1.10"
65+
SparseMatrixColorings = "0.4.8"
6166
StaticArraysCore = "1.4"
6267
Test = "1.10"
6368
julia = "1.10"

lib/NonlinearSolveBase/ext/NonlinearSolveBaseLinearSolveExt.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module NonlinearSolveBaseLinearSolveExt
33
using ArrayInterface: ArrayInterface
44
using CommonSolve: CommonSolve, init, solve!
55
using LinearAlgebra: ColumnNorm
6-
using LinearSolve: LinearSolve, QRFactorization
6+
using LinearSolve: LinearSolve, QRFactorization, SciMLLinearSolveAlgorithm
77
using NonlinearSolveBase: NonlinearSolveBase, LinearSolveJLCache, LinearSolveResult, Utils
88
using SciMLBase: ReturnCode, LinearProblem
99

@@ -81,7 +81,12 @@ function (cache::LinearSolveJLCache)(;
8181
return LinearSolveResult(; linres.u)
8282
end
8383

84-
NonlinearSolveBase.needs_square_A(linsolve, ::Any) = LinearSolve.needs_square_A(linsolve)
84+
function NonlinearSolveBase.needs_square_A(linsolve::SciMLLinearSolveAlgorithm, ::Any)
85+
return LinearSolve.needs_square_A(linsolve)
86+
end
87+
function NonlinearSolveBase.needs_concrete_A(linsolve::SciMLLinearSolveAlgorithm)
88+
return LinearSolve.needs_concrete_A(linsolve)
89+
end
8590

8691
update_A!(cache::LinearSolveJLCache, ::Nothing, reuse) = cache
8792
function update_A!(cache::LinearSolveJLCache, A, reuse)
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
module NonlinearSolveBaseSparseArraysExt
22

33
using NonlinearSolveBase: NonlinearSolveBase
4-
using SparseArrays: AbstractSparseMatrixCSC, nonzeros
4+
using SparseArrays: AbstractSparseMatrix, AbstractSparseMatrixCSC, nonzeros
55

66
function NonlinearSolveBase.NAN_CHECK(x::AbstractSparseMatrixCSC)
77
return any(NonlinearSolveBase.NAN_CHECK, nonzeros(x))
88
end
99

10+
NonlinearSolveBase.sparse_or_structured_prototype(::AbstractSparseMatrix) = true
11+
1012
end
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
module NonlinearSolveBaseSparseMatrixColoringsExt
2+
3+
using ADTypes: ADTypes, AbstractADType
4+
using NonlinearSolveBase: NonlinearSolveBase, Utils
5+
using SciMLBase: SciMLBase, NonlinearFunction
6+
using SparseMatrixColorings: ConstantColoringAlgorithm, GreedyColoringAlgorithm,
7+
LargestFirst
8+
9+
Utils.is_extension_loaded(::Val{:SparseMatrixColorings}) = true
10+
11+
function NonlinearSolveBase.select_fastest_coloring_algorithm(::Val{:SparseMatrixColorings},
12+
prototype, f::NonlinearFunction, ad::AbstractADType)
13+
prototype === nothing && return GreedyColoringAlgorithm(LargestFirst())
14+
if SciMLBase.has_colorvec(f)
15+
return ConstantColoringAlgorithm{ifelse(
16+
ADTypes.mode(ad) isa ADTypes.ReverseMode, :row, :column)}(
17+
prototype, f.colorvec)
18+
end
19+
return GreedyColoringAlgorithm(LargestFirst())
20+
end
21+
22+
end

lib/NonlinearSolveBase/src/NonlinearSolveBase.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
module NonlinearSolveBase
22

3-
using ADTypes: ADTypes, AbstractADType
3+
using ADTypes: ADTypes, AbstractADType, AutoSparse, NoSparsityDetector,
4+
KnownJacobianSparsityDetector
45
using Adapt: WrappedArray
56
using ArrayInterface: ArrayInterface
67
using CommonSolve: CommonSolve, init
78
using Compat: @compat
89
using ConcreteStructs: @concrete
9-
using DifferentiationInterface: DifferentiationInterface
10+
using DifferentiationInterface: DifferentiationInterface, Constant
1011
using EnzymeCore: EnzymeCore
1112
using FastClosures: @closure
1213
using FunctionProperties: hasbranching
@@ -17,8 +18,8 @@ using RecursiveArrayTools: AbstractVectorOfArray, ArrayPartition
1718
using SciMLBase: SciMLBase, ReturnCode, AbstractODEIntegrator, AbstractNonlinearProblem,
1819
AbstractNonlinearAlgorithm, AbstractNonlinearFunction,
1920
NonlinearProblem, NonlinearLeastSquaresProblem, StandardNonlinearProblem,
20-
NullParameters, NLStats, LinearProblem, isinplace, warn_paramtype,
21-
@add_kwonly
21+
NonlinearFunction, NullParameters, NLStats, LinearProblem
22+
using SciMLJacobianOperators: JacobianOperator, StatefulJacobianOperator
2223
using SciMLOperators: AbstractSciMLOperator, IdentityOperator
2324
using StaticArraysCore: StaticArray, SMatrix, SArray, MArray
2425

@@ -44,9 +45,10 @@ include("linear_solve.jl")
4445
(select_forward_mode_autodiff, select_reverse_mode_autodiff,
4546
select_jacobian_autodiff))
4647

47-
# public for NonlinearSolve.jl to use
48+
# public for NonlinearSolve.jl and subpackages to use
4849
@compat(public, (InternalAPI, supports_line_search, supports_trust_region, set_du!))
49-
@compat(public, (construct_linear_solver, needs_square_A))
50+
@compat(public, (construct_linear_solver, needs_square_A, needs_concrete_A))
51+
@compat(public, (construct_jacobian_cache,))
5052

5153
export RelTerminationMode, AbsTerminationMode,
5254
NormTerminationMode, RelNormTerminationMode, AbsNormTerminationMode,

lib/NonlinearSolveBase/src/abstract_types.jl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module InternalAPI
22

33
function init end
44
function solve! end
5+
function reinit! end
56

67
end
78

@@ -32,14 +33,34 @@ abstract type AbstractNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end
3233

3334
get_name(alg::AbstractNonlinearSolveAlgorithm) = Utils.safe_getproperty(alg, Val(:name))
3435

36+
"""
37+
concrete_jac(alg::AbstractNonlinearSolveAlgorithm)::Bool
38+
39+
Whether the algorithm uses a concrete Jacobian.
40+
"""
3541
function concrete_jac(alg::AbstractNonlinearSolveAlgorithm)
3642
return concrete_jac(Utils.safe_getproperty(alg, Val(:concrete_jac)))
3743
end
38-
concrete_jac(::Missing) = missing
44+
concrete_jac(::Missing) = false
45+
concrete_jac(::Nothing) = false
3946
concrete_jac(v::Bool) = v
4047
concrete_jac(::Val{false}) = false
4148
concrete_jac(::Val{true}) = true
4249

4350
abstract type AbstractNonlinearSolveCache end
4451

52+
"""
53+
AbstractLinearSolverCache
54+
55+
Abstract Type for all Linear Solvers used in NonlinearSolve. Subtypes of these are
56+
meant to be constructured via [`construct_linear_solver`](@ref).
57+
"""
4558
abstract type AbstractLinearSolverCache end
59+
60+
"""
61+
AbstractJacobianCache
62+
63+
Abstract Type for all Jacobian Caches used in NonlinearSolve. Subtypes of these are
64+
meant to be constructured via [`construct_jacobian_cache`](@ref).
65+
"""
66+
abstract type AbstractJacobianCache end

lib/NonlinearSolveBase/src/immutable_problem.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@ struct ImmutableNonlinearProblem{uType, iip, P, F, K, PT} <:
66
problem_type::PT
77
kwargs::K
88

9-
@add_kwonly function ImmutableNonlinearProblem{iip}(
9+
SciMLBase.@add_kwonly function ImmutableNonlinearProblem{iip}(
1010
f::AbstractNonlinearFunction{iip}, u0, p = NullParameters(),
1111
problem_type = StandardNonlinearProblem(); kwargs...) where {iip}
1212
if haskey(kwargs, :p)
1313
error("`p` specified as a keyword argument `p = $(kwargs[:p])` to \
1414
`NonlinearProblem`. This is not supported.")
1515
end
16-
warn_paramtype(p)
16+
SciMLBase.warn_paramtype(p)
1717
return new{
1818
typeof(u0), iip, typeof(p), typeof(f), typeof(kwargs), typeof(problem_type)}(
1919
f, u0, p, problem_type, kwargs)
@@ -31,27 +31,26 @@ struct ImmutableNonlinearProblem{uType, iip, P, F, K, PT} <:
3131
end
3232

3333
"""
34-
Define a nonlinear problem using an instance of
35-
[`AbstractNonlinearFunction`](@ref AbstractNonlinearFunction).
34+
Define a nonlinear problem using an instance of [`AbstractNonlinearFunction`](@ref).
3635
"""
3736
function ImmutableNonlinearProblem(
3837
f::AbstractNonlinearFunction, u0, p = NullParameters(); kwargs...)
39-
return ImmutableNonlinearProblem{isinplace(f)}(f, u0, p; kwargs...)
38+
return ImmutableNonlinearProblem{SciMLBase.isinplace(f)}(f, u0, p; kwargs...)
4039
end
4140

4241
function ImmutableNonlinearProblem(f, u0, p = NullParameters(); kwargs...)
4342
return ImmutableNonlinearProblem(NonlinearFunction(f), u0, p; kwargs...)
4443
end
4544

4645
"""
47-
Define a ImmutableNonlinearProblem problem from SteadyStateProblem
46+
Define a ImmutableNonlinearProblem problem from SteadyStateProblem.
4847
"""
4948
function ImmutableNonlinearProblem(prob::AbstractNonlinearProblem)
50-
return ImmutableNonlinearProblem{isinplace(prob)}(prob.f, prob.u0, prob.p)
49+
return ImmutableNonlinearProblem{SciMLBase.isinplace(prob)}(prob.f, prob.u0, prob.p)
5150
end
5251

5352
function Base.convert(
5453
::Type{ImmutableNonlinearProblem}, prob::T) where {T <: NonlinearProblem}
55-
return ImmutableNonlinearProblem{isinplace(prob)}(
54+
return ImmutableNonlinearProblem{SciMLBase.isinplace(prob)}(
5655
prob.f, prob.u0, prob.p, prob.problem_type; prob.kwargs...)
5756
end

0 commit comments

Comments
 (0)