Skip to content

Commit 07051a1

Browse files
committed
Faster sparse matrix construction
1 parent 8544436 commit 07051a1

File tree

5 files changed

+37
-35
lines changed

5 files changed

+37
-35
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ArrayDiff"
22
uuid = "c45fa1ca-6901-44ac-ae5b-5513a4852d50"
3-
authors = ["Benoît Legat <benoit.legat@gmail.com>"]
43
version = "0.1.0"
4+
authors = ["Benoît Legat <benoit.legat@gmail.com>"]
55

66
[deps]
77
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"

src/ArrayDiff.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,26 @@
66

77
module ArrayDiff
88

9+
import LinearAlgebra as LA
10+
import SparseArrays
11+
import SparseMatrixColorings as SMC
912
import ForwardDiff
1013
import MathOptInterface as MOI
1114
const Nonlinear = MOI.Nonlinear
12-
import SparseArrays
13-
import SparseMatrixColorings
1415

1516
"""
16-
Mode(coloring_algorithm::SparseMatrixColorings.GreedyColoringAlgorithm) <: AbstractAutomaticDifferentiation
17+
Mode(coloring_algorithm::SMC.GreedyColoringAlgorithm) <: AbstractAutomaticDifferentiation
1718
1819
Fork of `MOI.Nonlinear.SparseReverseMode` to add array support.
1920
"""
20-
struct Mode{C<:SparseMatrixColorings.GreedyColoringAlgorithm} <:
21+
struct Mode{C<:SMC.GreedyColoringAlgorithm} <:
2122
MOI.Nonlinear.AbstractAutomaticDifferentiation
2223
coloring_algorithm::C
2324
end
2425

2526
function Mode()
2627
return Mode(
27-
SparseMatrixColorings.GreedyColoringAlgorithm(;
28+
SMC.GreedyColoringAlgorithm(;
2829
decompression = :substitution,
2930
),
3031
)

src/coloring.jl

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@
66

77
"""
88
struct ColoringResult
9-
result::SparseMatrixColorings.TreeSetColoringResult
9+
result::SMC.TreeSetColoringResult
1010
local_indices::Vector{Int} # map from local to global indices
1111
end
1212
1313
Wrapper around TreeSetColoringResult that also stores local_indices mapping.
1414
"""
15-
struct ColoringResult{R<:SparseMatrixColorings.AbstractColoringResult}
15+
struct ColoringResult{R<:SMC.AbstractColoringResult}
1616
result::R
1717
local_indices::Vector{Int} # map from local to global indices
1818
end
@@ -21,7 +21,7 @@ end
2121
_hessian_color_preprocess(
2222
edgelist,
2323
num_total_var,
24-
algo::SparseMatrixColorings.GreedyColoringAlgorithm,
24+
algo::SMC.GreedyColoringAlgorithm,
2525
seen_idx = MOI.Nonlinear.Coloring.IndexedSet(0),
2626
)
2727
@@ -35,7 +35,7 @@ SparseMatrixColorings.
3535
function _hessian_color_preprocess(
3636
edgelist,
3737
num_total_var,
38-
algo::SparseMatrixColorings.GreedyColoringAlgorithm,
38+
algo::SMC.GreedyColoringAlgorithm,
3939
seen_idx = MOI.Nonlinear.Coloring.IndexedSet(0),
4040
)
4141
resize!(seen_idx, num_total_var)
@@ -45,6 +45,10 @@ function _hessian_color_preprocess(
4545
push!(seen_idx, j)
4646
push!(I, i)
4747
push!(J, j)
48+
if i != j
49+
push!(I, j)
50+
push!(J, i)
51+
end
4852
end
4953
local_indices = sort!(collect(seen_idx))
5054
empty!(seen_idx)
@@ -56,12 +60,12 @@ function _hessian_color_preprocess(
5660
# The I and J vectors are already empty, which is correct
5761
# For the result, we'll create a minimal valid structure with a diagonal element
5862
# Note: This case should rarely occur in practice
59-
S = SparseArrays.spdiagm(0 => [true])
60-
problem = SparseMatrixColorings.ColoringProblem(;
63+
S = SMC.SparsityPatternCSC(SparseArrays.spdiagm(0 => [true]))
64+
problem = SMC.ColoringProblem(;
6165
structure = :symmetric,
6266
partition = :column,
6367
)
64-
tree_result = SparseMatrixColorings.coloring(S, problem, algo)
68+
tree_result = SMC.coloring(S, problem, algo)
6569
result = ColoringResult(tree_result, Int[])
6670
return I, J, result
6771
end
@@ -78,19 +82,16 @@ function _hessian_color_preprocess(
7882

7983
# Create sparsity pattern matrix
8084
n = length(local_indices)
81-
S = SparseArrays.spzeros(Bool, n, n)
82-
for k in eachindex(I)
83-
i, j = I[k], J[k]
84-
S[i, j] = true
85-
S[j, i] = true # symmetric
86-
end
85+
S = SMC.SparsityPatternCSC(
86+
SparseArrays.sparse(I, J, trues(length(I)), n, n, &)
87+
)
8788

88-
# Perform coloring using SparseMatrixColorings
89-
problem = SparseMatrixColorings.ColoringProblem(;
89+
# Perform coloring using SMC
90+
problem = SMC.ColoringProblem(;
9091
structure = :symmetric,
9192
partition = :column,
9293
)
93-
tree_result = SparseMatrixColorings.coloring(S, problem, algo)
94+
tree_result = SMC.coloring(S, problem, algo)
9495

9596
# Reconstruct I and J from the tree structure (matching original _indirect_recover_structure)
9697
# First add all diagonal elements
@@ -148,7 +149,7 @@ Allocate a seed matrix for the coloring result.
148149
"""
149150
function _seed_matrix(result::ColoringResult)
150151
n = length(result.local_indices)
151-
ncolors = SparseMatrixColorings.ncolors(result.result)
152+
ncolors = SMC.ncolors(result.result)
152153
return Matrix{Float64}(undef, n, ncolors)
153154
end
154155

@@ -158,10 +159,10 @@ end
158159
Prepare the seed matrix R for Hessian computation.
159160
"""
160161
function _prepare_seed_matrix!(R, result::ColoringResult)
161-
color = SparseMatrixColorings.column_colors(result.result)
162+
color = SMC.column_colors(result.result)
162163
N = length(result.local_indices)
163164
@assert N == size(R, 1)
164-
@assert size(R, 2) == SparseMatrixColorings.ncolors(result.result)
165+
@assert size(R, 2) == SMC.ncolors(result.result)
165166
fill!(R, 0.0)
166167
for i in 1:N
167168
if color[i] > 0
@@ -190,7 +191,7 @@ function _recover_from_matmat!(
190191
stored_values::AbstractVector{T},
191192
) where {T}
192193
tree_result = result.result
193-
color = SparseMatrixColorings.column_colors(tree_result)
194+
color = SMC.column_colors(tree_result)
194195
N = length(result.local_indices)
195196
# Compute number of off-diagonal nonzeros from the length of V
196197
# V contains N diagonal elements + nnz_offdiag off-diagonal elements

src/types.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ function _subexpression_and_linearity(
6363
linearity
6464
end
6565

66-
struct _FunctionStorage{R<:SparseMatrixColorings.AbstractColoringResult}
66+
struct _FunctionStorage{R<:SMC.AbstractColoringResult}
6767
expr::_SubexpressionStorage
6868
grad_sparsity::Vector{Int}
6969
# Nonzero pattern of Hessian matrix
@@ -80,7 +80,7 @@ struct _FunctionStorage{R<:SparseMatrixColorings.AbstractColoringResult}
8080
coloring_storage::MOI.Nonlinear.ReverseAD.Coloring.IndexedSet,
8181
coloring_algorithm::Union{
8282
Nothing,
83-
SparseMatrixColorings.GreedyColoringAlgorithm,
83+
SMC.GreedyColoringAlgorithm,
8484
},
8585
subexpressions::Vector{_SubexpressionStorage},
8686
dependent_subexpressions,
@@ -141,7 +141,7 @@ end
141141
NLPEvaluator(
142142
model::Nonlinear.Model,
143143
ordered_variables::Vector{MOI.VariableIndex},
144-
coloring_algorithm::SparseMatrixColorings.AbstractColoringAlgorithm = SparseMatrixColorings.GreedyColoringAlgorithm(; decompression=:substitution),
144+
coloring_algorithm::SMC.AbstractColoringAlgorithm = SMC.GreedyColoringAlgorithm(; decompression=:substitution),
145145
)
146146
147147
Return an `NLPEvaluator` object that implements the `MOI.AbstractNLPEvaluator`
@@ -152,7 +152,7 @@ interface.
152152
"""
153153
mutable struct NLPEvaluator{
154154
R,
155-
C<:SparseMatrixColorings.GreedyColoringAlgorithm,
155+
C<:SMC.GreedyColoringAlgorithm,
156156
} <: MOI.AbstractNLPEvaluator
157157
data::Nonlinear.Model
158158
ordered_variables::Vector{MOI.VariableIndex}
@@ -193,18 +193,18 @@ mutable struct NLPEvaluator{
193193
function NLPEvaluator(
194194
data::Nonlinear.Model,
195195
ordered_variables::Vector{MOI.VariableIndex},
196-
coloring_algorithm::SparseMatrixColorings.GreedyColoringAlgorithm = SparseMatrixColorings.GreedyColoringAlgorithm(;
196+
coloring_algorithm::SMC.GreedyColoringAlgorithm = SMC.GreedyColoringAlgorithm(;
197197
decompression = :substitution,
198198
),
199199
)
200-
problem = SparseMatrixColorings.ColoringProblem(;
200+
problem = SMC.ColoringProblem(;
201201
structure = :symmetric,
202202
partition = :column,
203203
)
204204
C = typeof(coloring_algorithm)
205205
R = Base.promote_op(
206-
SparseMatrixColorings.coloring,
207-
SparseArrays.SparseMatrixCSC{Bool,Int},
206+
SMC.coloring,
207+
SMC.SparsityPatternCSC{Int},
208208
typeof(problem),
209209
C,
210210
)

test/ReverseAD.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,7 @@ end
481481
function test_coloring_end_to_end_hessian_coloring_and_recovery()
482482
# Test the new coloring API through the compatibility layer
483483
coloring_algorithm =
484-
ArrayDiff.SparseMatrixColorings.GreedyColoringAlgorithm(;
484+
ArrayDiff.SMC.GreedyColoringAlgorithm(;
485485
decompression = :substitution,
486486
)
487487
I, J, rinfo = ArrayDiff._hessian_color_preprocess(

0 commit comments

Comments
 (0)