Skip to content

Commit c8d847e

Browse files
Merge pull request #2853 from ChrisRackauckas-Claude/sparsearrays-extension
Move SparseArrays to extension for OrdinaryDiffEqDifferentiation
2 parents ab317cd + 63a6930 commit c8d847e

File tree

6 files changed

+56
-14
lines changed

6 files changed

+56
-14
lines changed

lib/OrdinaryDiffEqDifferentiation/Project.toml

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
1515
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1616
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1717
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
18-
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1918
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
2019
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
2120
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
@@ -24,6 +23,12 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
2423
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
2524
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
2625

26+
[weakdeps]
27+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
28+
29+
[extensions]
30+
OrdinaryDiffEqDifferentiationSparseArraysExt = "SparseArrays"
31+
2732
[extras]
2833
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
2934
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
@@ -32,6 +37,7 @@ DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
3237
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
3338
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
3439
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
40+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
3541

3642
[compat]
3743
ForwardDiff = "0.10.38, 1"
@@ -48,8 +54,8 @@ ConstructionBase = "1.5.8"
4854
LinearAlgebra = "1.10"
4955
SciMLBase = "2.99"
5056
OrdinaryDiffEqCore = "1.29.0"
51-
SparseArrays = "1.10"
5257
ConcreteStructs = "0.2"
58+
SparseArrays = "1.10"
5359
Aqua = "0.8.11"
5460
ArrayInterface = "7.19"
5561
StaticArrays = "1.9"
@@ -63,7 +69,7 @@ SafeTestsets = "0.1.0"
6369
SciMLOperators = "1.4"
6470

6571
[targets]
66-
test = ["DiffEqDevTools", "Random", "SafeTestsets", "Test", "JET", "Aqua", "AllocCheck"]
72+
test = ["DiffEqDevTools", "Random", "SafeTestsets", "Test", "JET", "Aqua", "AllocCheck", "SparseArrays"]
6773

6874
[sources.OrdinaryDiffEqCore]
6975
path = "../OrdinaryDiffEqCore"
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
module OrdinaryDiffEqDifferentiationSparseArraysExt
2+
3+
using OrdinaryDiffEqDifferentiation
4+
import SparseArrays
5+
import SparseArrays: nonzeros, spzeros, SparseMatrixCSC, AbstractSparseMatrix
6+
7+
# Override the sparse checking functions
8+
OrdinaryDiffEqDifferentiation.is_sparse(::AbstractSparseMatrix) = true
9+
OrdinaryDiffEqDifferentiation.is_sparse_csc(::SparseMatrixCSC) = true
10+
11+
# Override the sparse array manipulation functions
12+
OrdinaryDiffEqDifferentiation.nonzeros(A::AbstractSparseMatrix) = nonzeros(A)
13+
OrdinaryDiffEqDifferentiation.spzeros(T::Type, m::Integer, n::Integer) = spzeros(T, m, n)
14+
15+
# Helper functions for accessing sparse matrix internals
16+
OrdinaryDiffEqDifferentiation.get_nzval(A::SparseMatrixCSC) = A.nzval
17+
OrdinaryDiffEqDifferentiation.set_all_nzval!(A::SparseMatrixCSC, val) = (A.nzval .= val; A)
18+
19+
end

lib/OrdinaryDiffEqDifferentiation/src/OrdinaryDiffEqDifferentiation.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ import DiffEqBase
1313
import LinearAlgebra
1414
import LinearAlgebra: Diagonal, I, UniformScaling, diagind, mul!, lmul!, axpby!, opnorm, lu
1515
import LinearAlgebra: LowerTriangular, UpperTriangular
16-
import SparseArrays: SparseMatrixCSC, AbstractSparseMatrix, nonzeros, sparse, spzeros
1716
import ArrayInterface
1817
import ArrayInterface: fast_scalar_indexing, zeromatrix, lu_instance
1918

@@ -67,6 +66,24 @@ else
6766
struct OrdinaryDiffEqTag end
6867
end
6968

69+
# Functions for sparse array handling - will be overloaded by extension
70+
# Default implementations return false/error for non-sparse types
71+
is_sparse(::Any) = false
72+
is_sparse_csc(::Any) = false
73+
74+
# These will error if called without the extension, but should never be called
75+
# on non-sparse types due to the is_sparse checks
76+
function nonzeros end
77+
function spzeros end
78+
function get_nzval end
79+
function set_all_nzval! end
80+
81+
# Provide error messages if these are called without extension
82+
nonzeros(A) = error("SparseArrays extension not loaded. Please load SparseArrays to use sparse matrix functionality.")
83+
spzeros(args...) = error("SparseArrays extension not loaded. Please load SparseArrays to use sparse matrix functionality.")
84+
get_nzval(A) = error("SparseArrays extension not loaded. Please load SparseArrays to use sparse matrix functionality.")
85+
set_all_nzval!(A, val) = error("SparseArrays extension not loaded. Please load SparseArrays to use sparse matrix functionality.")
86+
7087
include("alg_utils.jl")
7188
include("linsolve_utils.jl")
7289
include("derivative_utils.jl")

lib/OrdinaryDiffEqDifferentiation/src/alg_utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ function prepare_user_sparsity(ad_alg, prob)
108108
sparsity = prob.f.sparsity
109109

110110
if !isnothing(sparsity) && !(ad_alg isa AutoSparse)
111-
if sparsity isa SparseMatrixCSC && !SciMLBase.has_jac(prob.f)
111+
if is_sparse_csc(sparsity) && !SciMLBase.has_jac(prob.f)
112112
if prob.f.mass_matrix isa UniformScaling
113113
idxs = diagind(sparsity)
114114
@. @view(sparsity[idxs]) = 1

lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -179,10 +179,10 @@ function calc_J!(J, integrator, cache, next_step::Bool = false)
179179
# we need to set all nzval to a non-zero number
180180
# otherwise in the following line any zero gets interpreted as a structural zero
181181
if !isnothing(integrator.f.jac_prototype) &&
182-
integrator.f.jac_prototype isa SparseMatrixCSC
183-
integrator.f.jac_prototype.nzval .= true
182+
is_sparse_csc(integrator.f.jac_prototype)
183+
set_all_nzval!(integrator.f.jac_prototype, true)
184184
J .= true .* integrator.f.jac_prototype
185-
J.nzval .= false
185+
set_all_nzval!(J, false)
186186
f.jac(J, duprev, uprev, p, uf.α * uf.invγdt, t)
187187
else
188188
f.jac(J, duprev, uprev, p, uf.α * uf.invγdt, t)
@@ -203,10 +203,10 @@ function calc_J!(J, integrator, cache, next_step::Bool = false)
203203
# we need to set all nzval to a non-zero number
204204
# otherwise in the following line any zero gets interpreted as a structural zero
205205
if !isnothing(integrator.f.jac_prototype) &&
206-
integrator.f.jac_prototype isa SparseMatrixCSC
207-
integrator.f.jac_prototype.nzval .= true
206+
is_sparse_csc(integrator.f.jac_prototype)
207+
set_all_nzval!(integrator.f.jac_prototype, true)
208208
J .= true .* integrator.f.jac_prototype
209-
J.nzval .= false
209+
set_all_nzval!(J, false)
210210
f.jac(J, uprev, p, t)
211211
else
212212
f.jac(J, uprev, p, t)
@@ -278,7 +278,7 @@ mutable struct WOperator{IIP, T,
278278
if AJ isa AbstractMatrix
279279
mm = mass_matrix isa MatrixOperator ?
280280
convert(AbstractMatrix, mass_matrix) : mass_matrix
281-
if AJ isa AbstractSparseMatrix
281+
if is_sparse(AJ)
282282

283283
# If gamma is zero, then it's just an initialization and we want to make sure
284284
# we get the right sparsity pattern. If gamma is not zero, then it's a case where

lib/OrdinaryDiffEqDifferentiation/src/derivative_wrappers.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ function build_jac_config(alg, f::F1, uf::F2, du1, uprev,
249249
(concrete_jac(alg) !== nothing && concrete_jac(alg)))
250250
jac_prototype = f.jac_prototype
251251

252-
if jac_prototype isa SparseMatrixCSC
252+
if is_sparse_csc(jac_prototype)
253253
if f.mass_matrix isa UniformScaling
254254
idxs = diagind(jac_prototype)
255255
@. @view(jac_prototype[idxs]) = 1
@@ -396,7 +396,7 @@ end
396396
function sparsity_colorvec(f, x)
397397
sparsity = f.sparsity
398398

399-
if sparsity isa SparseMatrixCSC
399+
if is_sparse_csc(sparsity)
400400
if f.mass_matrix isa UniformScaling
401401
idxs = diagind(sparsity)
402402
@. @view(sparsity[idxs]) = 1

0 commit comments

Comments
 (0)