Skip to content

Commit 4d0b3b4

Browse files
committed
Fix format
1 parent 2be2a66 commit 4d0b3b4

File tree

6 files changed

+77
-40
lines changed

6 files changed

+77
-40
lines changed

src/ArrayDiff.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,18 @@ import SparseMatrixColorings
1717
1818
Fork of `MOI.Nonlinear.SparseReverseMode` to add array support.
1919
"""
20-
struct Mode{C<:SparseMatrixColorings.GreedyColoringAlgorithm} <: MOI.Nonlinear.AbstractAutomaticDifferentiation
20+
struct Mode{C<:SparseMatrixColorings.GreedyColoringAlgorithm} <:
21+
MOI.Nonlinear.AbstractAutomaticDifferentiation
2122
coloring_algorithm::C
2223
end
2324

24-
Mode() = Mode(SparseMatrixColorings.GreedyColoringAlgorithm(; decompression=:substitution))
25+
function Mode()
26+
return Mode(
27+
SparseMatrixColorings.GreedyColoringAlgorithm(;
28+
decompression = :substitution,
29+
),
30+
)
31+
end
2532

2633
function MOI.Nonlinear.Evaluator(
2734
model::MOI.Nonlinear.Model,

src/coloring_compat.jl

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ function _hessian_color_preprocess(
4848
end
4949
local_indices = sort!(collect(seen_idx))
5050
empty!(seen_idx)
51-
51+
5252
# Handle empty case (no edges in Hessian)
5353
if isempty(local_indices)
5454
# Return empty structure - no variables to color
@@ -57,24 +57,30 @@ function _hessian_color_preprocess(
5757
# For the result, we'll create a minimal valid structure with a diagonal element
5858
# Note: This case should rarely occur in practice
5959
S = SparseArrays.spdiagm(0 => [true])
60-
problem = SparseMatrixColorings.ColoringProblem(; structure=:symmetric, partition=:column)
60+
problem = SparseMatrixColorings.ColoringProblem(;
61+
structure = :symmetric,
62+
partition = :column,
63+
)
6164
tree_result = SparseMatrixColorings.coloring(S, problem, algo)
6265
result = ColoringResult(tree_result, Int[])
6366
return I, J, result
6467
end
65-
68+
6669
# Also handle case where we have vertices but no edges (diagonal-only Hessian)
6770
if isempty(I)
6871
# Create identity matrix pattern (diagonal only)
6972
n = length(local_indices)
7073
S = SparseArrays.spdiagm(0 => trues(n))
71-
problem = SparseMatrixColorings.ColoringProblem(; structure=:symmetric, partition=:column)
74+
problem = SparseMatrixColorings.ColoringProblem(;
75+
structure = :symmetric,
76+
partition = :column,
77+
)
7278
tree_result = SparseMatrixColorings.coloring(S, problem, algo)
7379
result = ColoringResult(tree_result, local_indices)
7480
# I and J are already empty, which is correct for no off-diagonal elements
7581
return I, J, result
7682
end
77-
83+
7884
global_to_local_idx = seen_idx.nzidx # steal for storage
7985
for k in eachindex(local_indices)
8086
global_to_local_idx[local_indices[k]] = k
@@ -84,7 +90,7 @@ function _hessian_color_preprocess(
8490
I[k] = global_to_local_idx[I[k]]
8591
J[k] = global_to_local_idx[J[k]]
8692
end
87-
93+
8894
# Create sparsity pattern matrix
8995
n = length(local_indices)
9096
S = SparseArrays.spzeros(Bool, n, n)
@@ -93,39 +99,42 @@ function _hessian_color_preprocess(
9399
S[i, j] = true
94100
S[j, i] = true # symmetric
95101
end
96-
102+
97103
# Perform coloring using SparseMatrixColorings
98-
problem = SparseMatrixColorings.ColoringProblem(; structure=:symmetric, partition=:column)
104+
problem = SparseMatrixColorings.ColoringProblem(;
105+
structure = :symmetric,
106+
partition = :column,
107+
)
99108
tree_result = SparseMatrixColorings.coloring(S, problem, algo)
100-
109+
101110
# Reconstruct I and J from the tree structure (matching original _indirect_recover_structure)
102111
# First add all diagonal elements
103112
N = length(local_indices)
104-
113+
105114
# Count off-diagonal elements from tree structure
106115
(; reverse_bfs_orders, tree_edge_indices, nt) = tree_result
107116
nnz_offdiag = 0
108117
for tree_idx in 1:nt
109118
first = tree_edge_indices[tree_idx]
110-
last = tree_edge_indices[tree_idx + 1] - 1
119+
last = tree_edge_indices[tree_idx+1] - 1
111120
nnz_offdiag += (last - first + 1)
112121
end
113-
122+
114123
I_new = Vector{Int}(undef, N + nnz_offdiag)
115124
J_new = Vector{Int}(undef, N + nnz_offdiag)
116125
k = 0
117-
126+
118127
# Add all diagonal elements
119128
for i in 1:N
120129
k += 1
121130
I_new[k] = local_indices[i]
122131
J_new[k] = local_indices[i]
123132
end
124-
133+
125134
# Then add off-diagonal elements from the tree structure
126135
for tree_idx in 1:nt
127136
first = tree_edge_indices[tree_idx]
128-
last = tree_edge_indices[tree_idx + 1] - 1
137+
last = tree_edge_indices[tree_idx+1] - 1
129138
for pos in first:last
130139
(i_local, j_local) = reverse_bfs_orders[pos]
131140
# Convert from local to global indices and normalize (lower triangle)
@@ -139,9 +148,9 @@ function _hessian_color_preprocess(
139148
J_new[k] = j_global
140149
end
141150
end
142-
151+
143152
@assert k == length(I_new)
144-
153+
145154
# Wrap result with local_indices
146155
result = ColoringResult(tree_result, local_indices)
147156
return I_new, J_new, result
@@ -203,7 +212,7 @@ function _recover_from_matmat!(
203212
# V contains N diagonal elements + nnz_offdiag off-diagonal elements
204213
nnz_offdiag = length(V) - N
205214
@assert length(stored_values) >= N
206-
215+
207216
# Recover diagonal elements
208217
k = 0
209218
for i in 1:N
@@ -214,23 +223,23 @@ function _recover_from_matmat!(
214223
V[k] = zero(T)
215224
end
216225
end
217-
226+
218227
# Recover off-diagonal elements using tree structure
219228
(; reverse_bfs_orders, tree_edge_indices, nt) = tree_result
220229
fill!(stored_values, zero(T))
221-
230+
222231
for tree_idx in 1:nt
223232
first = tree_edge_indices[tree_idx]
224-
last = tree_edge_indices[tree_idx + 1] - 1
225-
233+
last = tree_edge_indices[tree_idx+1] - 1
234+
226235
# Reset stored_values for vertices in this tree
227236
for pos in first:last
228237
(vertex, _) = reverse_bfs_orders[pos]
229238
stored_values[vertex] = zero(T)
230239
end
231240
(_, root) = reverse_bfs_orders[last]
232241
stored_values[root] = zero(T)
233-
242+
234243
# Recover edge values
235244
for pos in first:last
236245
(i, j) = reverse_bfs_orders[pos]
@@ -244,7 +253,7 @@ function _recover_from_matmat!(
244253
V[k] = value
245254
end
246255
end
247-
256+
248257
@assert k == length(V)
249258
return
250259
end

src/forward_over_reverse.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,12 @@ function _eval_hessian(
6565
)
6666
end
6767
# TODO(odow): consider reverting to a view.
68-
output_slice = _UnsafeVectorView{Float64}(nzcount, length(ex.hess_I), pointer(H))::_UnsafeVectorView{Float64}
69-
_recover_from_matmat!(
70-
output_slice,
71-
ex.seed_matrix,
72-
ex.rinfo,
73-
d.output_ϵ,
74-
)
68+
output_slice = _UnsafeVectorView{Float64}(
69+
nzcount,
70+
length(ex.hess_I),
71+
pointer(H),
72+
)::_UnsafeVectorView{Float64}
73+
_recover_from_matmat!(output_slice, ex.seed_matrix, ex.rinfo, d.output_ϵ)
7574
for i in 1:length(output_slice)
7675
output_slice[i] *= scale
7776
end

src/mathoptinterface_api.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@ function MOI.features_available(d::NLPEvaluator)
1919
return [:Grad, :Jac, :JacVec, :Hess, :HessVec]
2020
end
2121

22-
function MOI.initialize(d::NLPEvaluator{R}, requested_features::Vector{Symbol}) where {R}
22+
function MOI.initialize(
23+
d::NLPEvaluator{R},
24+
requested_features::Vector{Symbol},
25+
) where {R}
2326
# Check that we support the features requested by the user.
2427
available_features = MOI.features_available(d)
2528
for feature in requested_features

src/types.jl

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,10 @@ struct _FunctionStorage{R<:SparseMatrixColorings.AbstractColoringResult}
7878
expr::_SubexpressionStorage,
7979
num_variables,
8080
coloring_storage::IndexedSet,
81-
coloring_algorithm::Union{Nothing,SparseMatrixColorings.GreedyColoringAlgorithm},
81+
coloring_algorithm::Union{
82+
Nothing,
83+
SparseMatrixColorings.GreedyColoringAlgorithm,
84+
},
8285
subexpressions::Vector{_SubexpressionStorage},
8386
dependent_subexpressions,
8487
subexpression_edgelist,
@@ -147,7 +150,10 @@ interface.
147150
!!! warning
148151
Before using, you must initialize the evaluator using `MOI.initialize`.
149152
"""
150-
mutable struct NLPEvaluator{R,C<:SparseMatrixColorings.GreedyColoringAlgorithm} <: MOI.AbstractNLPEvaluator
153+
mutable struct NLPEvaluator{
154+
R,
155+
C<:SparseMatrixColorings.GreedyColoringAlgorithm,
156+
} <: MOI.AbstractNLPEvaluator
151157
data::Nonlinear.Model
152158
ordered_variables::Vector{MOI.VariableIndex}
153159
coloring_algorithm::C
@@ -187,9 +193,14 @@ mutable struct NLPEvaluator{R,C<:SparseMatrixColorings.GreedyColoringAlgorithm}
187193
function NLPEvaluator(
188194
data::Nonlinear.Model,
189195
ordered_variables::Vector{MOI.VariableIndex},
190-
coloring_algorithm::SparseMatrixColorings.GreedyColoringAlgorithm = SparseMatrixColorings.GreedyColoringAlgorithm(; decompression=:substitution),
196+
coloring_algorithm::SparseMatrixColorings.GreedyColoringAlgorithm = SparseMatrixColorings.GreedyColoringAlgorithm(;
197+
decompression = :substitution,
198+
),
191199
)
192-
problem = SparseMatrixColorings.ColoringProblem(; structure=:symmetric, partition=:column)
200+
problem = SparseMatrixColorings.ColoringProblem(;
201+
structure = :symmetric,
202+
partition = :column,
203+
)
193204
C = typeof(coloring_algorithm)
194205
R = Base.promote_op(
195206
SparseMatrixColorings.coloring,

test/ReverseAD.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -480,8 +480,16 @@ end
480480

481481
function test_coloring_end_to_end_hessian_coloring_and_recovery()
482482
# Test the new coloring API through the compatibility layer
483-
coloring_algorithm = ArrayDiff.SparseMatrixColorings.GreedyColoringAlgorithm(; decompression=:substitution)
484-
I, J, rinfo = ArrayDiff._hessian_color_preprocess(Set([(1, 2)]), 2, coloring_algorithm, ArrayDiff.IndexedSet(0))
483+
coloring_algorithm =
484+
ArrayDiff.SparseMatrixColorings.GreedyColoringAlgorithm(;
485+
decompression = :substitution,
486+
)
487+
I, J, rinfo = ArrayDiff._hessian_color_preprocess(
488+
Set([(1, 2)]),
489+
2,
490+
coloring_algorithm,
491+
ArrayDiff.IndexedSet(0),
492+
)
485493
R = ArrayDiff._seed_matrix(rinfo)
486494
ArrayDiff._prepare_seed_matrix!(R, rinfo)
487495
@test I == [1, 2, 2]

0 commit comments

Comments
 (0)