Skip to content

Commit 616d8ff

Browse files
Merge pull request #192 from avik-pal/ap/sparsearr
refactor: move `SparseArrays` into an extension
2 parents b574440 + 884422b commit 616d8ff

File tree

5 files changed

+72
-66
lines changed

5 files changed

+72
-66
lines changed

Project.toml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
11
name = "FiniteDiff"
22
uuid = "6a86dc24-6348-571c-b903-95158fe2bd41"
3-
version = "2.24.0"
3+
version = "2.25.0"
44

55
[deps]
66
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
9-
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
10-
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
119

1210
[weakdeps]
1311
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
1412
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
13+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1514
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1615

1716
[extensions]
1817
FiniteDiffBandedMatricesExt = "BandedMatrices"
1918
FiniteDiffBlockBandedMatricesExt = "BlockBandedMatrices"
19+
FiniteDiffSparseArraysExt = "SparseArrays"
2020
FiniteDiffStaticArraysExt = "StaticArrays"
2121

2222
[compat]
@@ -32,8 +32,9 @@ BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
3232
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
3333
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
3434
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
35+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
3536
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
3637
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3738

3839
[targets]
39-
test = ["Test", "BlockBandedMatrices", "BandedMatrices", "Pkg", "SafeTestsets", "StaticArrays"]
40+
test = ["Test", "BlockBandedMatrices", "BandedMatrices", "Pkg", "SafeTestsets", "SparseArrays", "StaticArrays"]

ext/FiniteDiffSparseArraysExt.jl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
module FiniteDiffSparseArraysExt
2+
3+
using SparseArrays
4+
using FiniteDiff
5+
6+
# jacobians.jl
7+
function FiniteDiff._make_Ji(::SparseMatrixCSC, rows_index, cols_index, dx, colorvec, color_i, nrows, ncols)
8+
pick_inds = [i for i in 1:length(rows_index) if colorvec[cols_index[i]] == color_i]
9+
rows_index_c = rows_index[pick_inds]
10+
cols_index_c = cols_index[pick_inds]
11+
Ji = sparse(rows_index_c, cols_index_c, dx[rows_index_c], nrows, ncols)
12+
Ji
13+
end
14+
15+
function FiniteDiff._make_Ji(::SparseMatrixCSC, xtype, dx, color_i, nrows, ncols)
16+
Ji = sparse(1:nrows, fill(color_i, nrows), dx, nrows, ncols)
17+
Ji
18+
end
19+
20+
@inline function FiniteDiff._colorediteration!(J, sparsity::SparseMatrixCSC, rows_index, cols_index, vfx, colorvec, color_i, ncols)
21+
@inbounds for col_index in 1:ncols
22+
if colorvec[col_index] == color_i
23+
@inbounds for row_index in view(sparsity.rowval, sparsity.colptr[col_index]:sparsity.colptr[col_index+1]-1)
24+
J[row_index, col_index] = vfx[row_index]
25+
end
26+
end
27+
end
28+
end
29+
30+
@inline FiniteDiff.fill_matrix!(J::AbstractSparseMatrix, v) = fill!(nonzeros(J), v)
31+
32+
@inline function FiniteDiff.fast_jacobian_setindex!(J::AbstractSparseMatrix, rows_index, cols_index, _color, color_i, vfx)
33+
@. FiniteDiff.void_setindex!((J.nzval,), getindex((J.nzval,), rows_index) + (getindex((_color,), cols_index) == color_i) * getindex((vfx,), rows_index), rows_index)
34+
end
35+
36+
# iteration_utils.jl
37+
## fast version for the case where J and sparsity have the same sparsity pattern
38+
@inline function FiniteDiff._colorediteration!(Jsparsity::SparseMatrixCSC, vfx, colorvec, color_i, ncols)
39+
@inbounds for col_index in 1:ncols
40+
if colorvec[col_index] == color_i
41+
@inbounds for spidx in nzrange(Jsparsity, col_index)
42+
row_index = Jsparsity.rowval[spidx]
43+
Jsparsity.nzval[spidx] = vfx[row_index]
44+
end
45+
end
46+
end
47+
end
48+
49+
FiniteDiff._use_findstructralnz(::SparseMatrixCSC) = false
50+
51+
FiniteDiff._use_sparseCSC_common_sparsity(J::SparseMatrixCSC, sparsity::SparseMatrixCSC) =
52+
((J.colptr == sparsity.colptr) && (J.rowval == sparsity.rowval))
53+
54+
55+
end

src/FiniteDiff.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ Fast non-allocating calculations of gradients, Jacobians, and Hessians with spar
55
"""
66
module FiniteDiff
77

8-
using LinearAlgebra, SparseArrays, ArrayInterface
8+
using LinearAlgebra, ArrayInterface
99

1010
import Base: resize!
1111

src/iteration_utils.jl

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,34 +6,8 @@
66
end
77
end
88

9-
@inline function _colorediteration!(J,sparsity::SparseMatrixCSC,rows_index,cols_index,vfx,colorvec,color_i,ncols)
10-
@inbounds for col_index in 1:ncols
11-
if colorvec[col_index] == color_i
12-
@inbounds for row_index in view(sparsity.rowval,sparsity.colptr[col_index]:sparsity.colptr[col_index+1]-1)
13-
J[row_index,col_index]=vfx[row_index]
14-
end
15-
end
16-
end
17-
end
18-
19-
# fast version for the case where J and sparsity have the same sparsity pattern
20-
@inline function _colorediteration!(Jsparsity::SparseMatrixCSC,vfx,colorvec,color_i,ncols)
21-
@inbounds for col_index in 1:ncols
22-
if colorvec[col_index] == color_i
23-
@inbounds for spidx in nzrange(Jsparsity, col_index)
24-
row_index = Jsparsity.rowval[spidx]
25-
Jsparsity.nzval[spidx]=vfx[row_index]
26-
end
27-
end
28-
end
29-
end
30-
319
#override default setting of using findstructralnz
3210
_use_findstructralnz(sparsity) = ArrayInterface.has_sparsestruct(sparsity)
33-
_use_findstructralnz(::SparseMatrixCSC) = false
3411

3512
# test if J, sparsity are both SparseMatrixCSC and have the same sparsity pattern of stored values
3613
_use_sparseCSC_common_sparsity(J, sparsity) = false
37-
_use_sparseCSC_common_sparsity(J::SparseMatrixCSC, sparsity::SparseMatrixCSC) =
38-
((J.colptr == sparsity.colptr) && (J.rowval == sparsity.rowval))
39-

src/jacobians.jl

Lines changed: 11 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -125,14 +125,6 @@ function JacobianCache(
125125
JacobianCache{typeof(_x1),typeof(_x2),typeof(_fx),typeof(fx1),typeof(colorvec),typeof(sparsity),fdtype,returntype}(_x1,_x2,_fx,fx1,colorvec,sparsity)
126126
end
127127

128-
function _make_Ji(::SparseMatrixCSC, rows_index,cols_index,dx,colorvec,color_i,nrows,ncols)
129-
pick_inds = [i for i in 1:length(rows_index) if colorvec[cols_index[i]] == color_i]
130-
rows_index_c = rows_index[pick_inds]
131-
cols_index_c = cols_index[pick_inds]
132-
Ji = sparse(rows_index_c, cols_index_c, dx[rows_index_c],nrows,ncols)
133-
Ji
134-
end
135-
136128
function _make_Ji(::AbstractArray, rows_index,cols_index,dx,colorvec,color_i,nrows,ncols)
137129
pick_inds = [i for i in 1:length(rows_index) if colorvec[cols_index[i]] == color_i]
138130
rows_index_c = rows_index[pick_inds]
@@ -145,12 +137,6 @@ function _make_Ji(::AbstractArray, rows_index,cols_index,dx,colorvec,color_i,nro
145137
Ji
146138
end
147139

148-
function _make_Ji(::SparseMatrixCSC, xtype, dx, color_i, nrows, ncols)
149-
Ji = sparse(1:nrows,fill(color_i,nrows),dx,nrows,ncols)
150-
Ji
151-
end
152-
153-
154140
function _make_Ji(::AbstractArray, xtype, dx, color_i, nrows, ncols)
155141
Ji = mapreduce(i -> i==color_i ? dx : zero(dx), hcat, 1:ncols)
156142
size(Ji) != (nrows, ncols) ? reshape(Ji, (nrows, ncols)) : Ji #branch when size(dx) == (1,) => size(Ji) == (1,) while size(J) == (1,1)
@@ -445,11 +431,7 @@ function finite_difference_jacobian!(
445431
end
446432

447433
if sparsity !== nothing
448-
if J isa AbstractSparseMatrix
449-
fill!(nonzeros(J),false)
450-
else
451-
fill!(J,false)
452-
end
434+
fill_matrix!(J, false)
453435
end
454436

455437
# fast path if J and sparsity are both AbstractSparseMatrix and have the same sparsity pattern
@@ -497,11 +479,7 @@ function finite_difference_jacobian!(
497479
J[rows_index, cols_index] .+= (colorvec[cols_index] .== color_i) .* vfx1[rows_index]
498480
+= means requires a zero'd out start
499481
=#
500-
if J isa AbstractSparseMatrix
501-
@. void_setindex!((J.nzval,), getindex((J.nzval,), rows_index) + (getindex((_color,), cols_index) == color_i) * getindex((vfx1,), rows_index), rows_index)
502-
else
503-
@. void_setindex!((J,), getindex((J,), rows_index, cols_index) + (getindex((_color,), cols_index) == color_i) * getindex((vfx1,), rows_index), rows_index, cols_index)
504-
end
482+
fast_jacobian_setindex!(J, rows_index, cols_index, _color, color_i, vfx1)
505483
end
506484
# Now return x1 back to its original value
507485
@. x1 = x1 - epsilon * (_color == color_i)
@@ -535,11 +513,7 @@ function finite_difference_jacobian!(
535513
_colorediteration!(J,sparsity,rows_index,cols_index,vfx1,colorvec,color_i,n)
536514
end
537515
else
538-
if J isa AbstractSparseMatrix
539-
@. void_setindex!((J.nzval,), getindex((J.nzval,), rows_index) + (getindex((_color,), cols_index) == color_i) * getindex((vfx1,), rows_index), rows_index)
540-
else
541-
@. void_setindex!((J,), getindex((J,), rows_index, cols_index) + (getindex((_color,), cols_index) == color_i) * getindex((vfx1,), rows_index), rows_index, cols_index)
542-
end
516+
fast_jacobian_setindex!(J, rows_index, cols_index, _color, color_i, vfx1)
543517
end
544518
@. x1 = x1 - epsilon * (_color == color_i)
545519
@. x = x + epsilon * (_color == color_i)
@@ -565,11 +539,7 @@ function finite_difference_jacobian!(
565539
_colorediteration!(J,sparsity,rows_index,cols_index,vfx,colorvec,color_i,n)
566540
end
567541
else
568-
if J isa AbstractSparseMatrix
569-
@. void_setindex!((J.nzval,), getindex((J.nzval,), rows_index) + (getindex((_color,), cols_index) == color_i) * getindex((vfx,),rows_index), rows_index)
570-
else
571-
@. void_setindex!((J,), getindex((J,), rows_index, cols_index) + (getindex((_color,), cols_index) == color_i) * getindex((vfx,), rows_index), rows_index, cols_index)
572-
end
542+
fast_jacobian_setindex!(J, rows_index, cols_index, _color, color_i, vfx)
573543
end
574544
@. x1 = x1 - im * epsilon * (_color == color_i)
575545
end
@@ -583,7 +553,13 @@ end
583553
function resize!(cache::JacobianCache, i::Int)
584554
resize!(cache.x1, i)
585555
resize!(cache.fx, i)
586-
cache.fx1 != nothing && resize!(cache.fx1, i)
556+
cache.fx1 !== nothing && resize!(cache.fx1, i)
587557
cache.colorvec = 1:i
588558
nothing
589559
end
560+
561+
@inline fill_matrix!(J, v) = fill!(J, v)
562+
563+
@inline function fast_jacobian_setindex!(J, rows_index, cols_index, _color, color_i, vfx)
564+
@. void_setindex!((J,), getindex((J,), rows_index, cols_index) + (getindex((_color,), cols_index) == color_i) * getindex((vfx,), rows_index), rows_index, cols_index)
565+
end

0 commit comments

Comments
 (0)