Skip to content

Commit 63a6930

Browse files
committed
Refactor to use function-based dispatch for sparse array checks
Per review feedback, replaced all `isa SparseMatrixCSC` and `isa AbstractSparseMatrix` checks with function calls (`is_sparse_csc()` and `is_sparse()`) that default to false. This allows the sparse code to completely compile out when SparseArrays is not loaded, improving performance for non-sparse cases. Changes: - Replaced type checks with is_sparse() and is_sparse_csc() function calls - Functions return false by default for any input - Extension overloads them to return true for sparse types - Removed unnecessary type aliases - All sparse-specific code paths are now behind function checks that compile out 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]>
1 parent 3c1d674 commit 63a6930

File tree

5 files changed

+23
-32
lines changed

5 files changed

+23
-32
lines changed

lib/OrdinaryDiffEqDifferentiation/ext/OrdinaryDiffEqDifferentiationSparseArraysExt.jl

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,26 +2,18 @@ module OrdinaryDiffEqDifferentiationSparseArraysExt
22

33
using OrdinaryDiffEqDifferentiation
44
import SparseArrays
5-
import SparseArrays: nonzeros, spzeros
5+
import SparseArrays: nonzeros, spzeros, SparseMatrixCSC, AbstractSparseMatrix
66

7-
# Set the type aliases when extension loads
8-
function __init__()
9-
OrdinaryDiffEqDifferentiation.SparseMatrixCSC = SparseArrays.SparseMatrixCSC
10-
OrdinaryDiffEqDifferentiation.AbstractSparseMatrix = SparseArrays.AbstractSparseMatrix
11-
end
7+
# Override the sparse checking functions
8+
OrdinaryDiffEqDifferentiation.is_sparse(::AbstractSparseMatrix) = true
9+
OrdinaryDiffEqDifferentiation.is_sparse_csc(::SparseMatrixCSC) = true
1210

13-
# Define functions that were previously imported directly
14-
OrdinaryDiffEqDifferentiation.nonzeros(A::SparseArrays.AbstractSparseMatrix) = nonzeros(A)
11+
# Override the sparse array manipulation functions
12+
OrdinaryDiffEqDifferentiation.nonzeros(A::AbstractSparseMatrix) = nonzeros(A)
1513
OrdinaryDiffEqDifferentiation.spzeros(T::Type, m::Integer, n::Integer) = spzeros(T, m, n)
1614

17-
# Helper function to check if a type is sparse
18-
OrdinaryDiffEqDifferentiation.is_sparse_type(::Type{<:SparseArrays.SparseMatrixCSC}) = true
19-
OrdinaryDiffEqDifferentiation.is_sparse_type(::SparseArrays.SparseMatrixCSC) = true
20-
OrdinaryDiffEqDifferentiation.is_sparse_type(::Type{<:SparseArrays.AbstractSparseMatrix}) = true
21-
OrdinaryDiffEqDifferentiation.is_sparse_type(::SparseArrays.AbstractSparseMatrix) = true
22-
2315
# Helper functions for accessing sparse matrix internals
24-
OrdinaryDiffEqDifferentiation.get_nzval(A::SparseArrays.SparseMatrixCSC) = A.nzval
25-
OrdinaryDiffEqDifferentiation.set_all_nzval!(A::SparseArrays.SparseMatrixCSC, val) = (A.nzval .= val; A)
16+
OrdinaryDiffEqDifferentiation.get_nzval(A::SparseMatrixCSC) = A.nzval
17+
OrdinaryDiffEqDifferentiation.set_all_nzval!(A::SparseMatrixCSC, val) = (A.nzval .= val; A)
2618

2719
end

lib/OrdinaryDiffEqDifferentiation/src/OrdinaryDiffEqDifferentiation.jl

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -66,22 +66,21 @@ else
6666
struct OrdinaryDiffEqTag end
6767
end
6868

69-
# Stub types and functions that will be defined by the SparseArrays extension
70-
# These are needed for type stability and method dispatch
71-
# Use a mutable placeholder that can be set by the extension
72-
SparseMatrixCSC = Any
73-
AbstractSparseMatrix = Any
69+
# Functions for sparse array handling - will be overloaded by extension
70+
# Default implementations return false/error for non-sparse types
71+
is_sparse(::Any) = false
72+
is_sparse_csc(::Any) = false
7473

75-
# Stub functions that will be overridden by the extension
74+
# These will error if called without the extension, but should never be called
75+
# on non-sparse types due to the is_sparse checks
7676
function nonzeros end
7777
function spzeros end
78-
function is_sparse_type end
7978
function get_nzval end
8079
function set_all_nzval! end
8180

82-
# Default implementations for non-sparse types
83-
is_sparse_type(::Type) = false
84-
is_sparse_type(::Any) = false
81+
# Provide error messages if these are called without extension
82+
nonzeros(A) = error("SparseArrays extension not loaded. Please load SparseArrays to use sparse matrix functionality.")
83+
spzeros(args...) = error("SparseArrays extension not loaded. Please load SparseArrays to use sparse matrix functionality.")
8584
get_nzval(A) = error("SparseArrays extension not loaded. Please load SparseArrays to use sparse matrix functionality.")
8685
set_all_nzval!(A, val) = error("SparseArrays extension not loaded. Please load SparseArrays to use sparse matrix functionality.")
8786

lib/OrdinaryDiffEqDifferentiation/src/alg_utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ function prepare_user_sparsity(ad_alg, prob)
108108
sparsity = prob.f.sparsity
109109

110110
if !isnothing(sparsity) && !(ad_alg isa AutoSparse)
111-
if sparsity isa SparseMatrixCSC && !SciMLBase.has_jac(prob.f)
111+
if is_sparse_csc(sparsity) && !SciMLBase.has_jac(prob.f)
112112
if prob.f.mass_matrix isa UniformScaling
113113
idxs = diagind(sparsity)
114114
@. @view(sparsity[idxs]) = 1

lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ function calc_J!(J, integrator, cache, next_step::Bool = false)
179179
# we need to set all nzval to a non-zero number
180180
# otherwise in the following line any zero gets interpreted as a structural zero
181181
if !isnothing(integrator.f.jac_prototype) &&
182-
integrator.f.jac_prototype isa SparseMatrixCSC
182+
is_sparse_csc(integrator.f.jac_prototype)
183183
set_all_nzval!(integrator.f.jac_prototype, true)
184184
J .= true .* integrator.f.jac_prototype
185185
set_all_nzval!(J, false)
@@ -203,7 +203,7 @@ function calc_J!(J, integrator, cache, next_step::Bool = false)
203203
# we need to set all nzval to a non-zero number
204204
# otherwise in the following line any zero gets interpreted as a structural zero
205205
if !isnothing(integrator.f.jac_prototype) &&
206-
integrator.f.jac_prototype isa SparseMatrixCSC
206+
is_sparse_csc(integrator.f.jac_prototype)
207207
set_all_nzval!(integrator.f.jac_prototype, true)
208208
J .= true .* integrator.f.jac_prototype
209209
set_all_nzval!(J, false)
@@ -278,7 +278,7 @@ mutable struct WOperator{IIP, T,
278278
if AJ isa AbstractMatrix
279279
mm = mass_matrix isa MatrixOperator ?
280280
convert(AbstractMatrix, mass_matrix) : mass_matrix
281-
if AJ isa AbstractSparseMatrix
281+
if is_sparse(AJ)
282282

283283
# If gamma is zero, then it's just an initialization and we want to make sure
284284
# we get the right sparsity pattern. If gamma is not zero, then it's a case where

lib/OrdinaryDiffEqDifferentiation/src/derivative_wrappers.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ function build_jac_config(alg, f::F1, uf::F2, du1, uprev,
249249
(concrete_jac(alg) !== nothing && concrete_jac(alg)))
250250
jac_prototype = f.jac_prototype
251251

252-
if jac_prototype isa SparseMatrixCSC
252+
if is_sparse_csc(jac_prototype)
253253
if f.mass_matrix isa UniformScaling
254254
idxs = diagind(jac_prototype)
255255
@. @view(jac_prototype[idxs]) = 1
@@ -396,7 +396,7 @@ end
396396
function sparsity_colorvec(f, x)
397397
sparsity = f.sparsity
398398

399-
if sparsity isa SparseMatrixCSC
399+
if is_sparse_csc(sparsity)
400400
if f.mass_matrix isa UniformScaling
401401
idxs = diagind(sparsity)
402402
@. @view(sparsity[idxs]) = 1

0 commit comments

Comments
 (0)