Skip to content

Commit 87c8016

Browse files
committed
Revisit the post-processing
1 parent 662df63 commit 87c8016

File tree

7 files changed

+589
-148
lines changed

7 files changed

+589
-148
lines changed

src/coloring.jl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ end
8181
"""
8282
star_coloring(
8383
g::AdjacencyGraph, vertices_in_order::AbstractVector, postprocessing::Bool;
84-
forced_colors::Union{AbstractVector,Nothing}=nothing
84+
postprocessing_minimizes::Symbol=:all_colors, forced_colors::Union{AbstractVector,Nothing}=nothing
8585
)
8686
8787
Compute a star coloring of all vertices in the adjacency graph `g` and return a tuple `(color, star_set)`, where
@@ -110,6 +110,7 @@ function star_coloring(
110110
g::AdjacencyGraph{T},
111111
vertices_in_order::AbstractVector{<:Integer},
112112
postprocessing::Bool;
113+
postprocessing_minimizes::Symbol=:all_colors,
113114
forced_colors::Union{AbstractVector{<:Integer},Nothing}=nothing,
114115
) where {T<:Integer}
115116
# Initialize data structures
@@ -168,7 +169,7 @@ function star_coloring(
168169
if postprocessing
169170
# Reuse the vector forbidden_colors to compute offsets during post-processing
170171
offsets = forbidden_colors
171-
postprocess!(color, star_set, g, offsets)
172+
postprocess!(color, star_set, g, offsets, postprocessing_minimizes)
172173
end
173174
return color, star_set
174175
end
@@ -250,7 +251,8 @@ struct StarSet{T}
250251
end
251252

252253
"""
253-
acyclic_coloring(g::AdjacencyGraph, vertices_in_order::AbstractVector, postprocessing::Bool)
254+
acyclic_coloring(g::AdjacencyGraph, vertices_in_order::AbstractVector, postprocessing::Bool;
255+
postprocessing_minimizes::Symbol=:all_colors)
254256
255257
Compute an acyclic coloring of all vertices in the adjacency graph `g` and return a tuple `(color, tree_set)`, where
256258
@@ -273,7 +275,10 @@ If `postprocessing=true`, some colors might be replaced with `0` (the "neutral"
273275
> [_New Acyclic and Star Coloring Algorithms with Application to Computing Hessians_](https://epubs.siam.org/doi/abs/10.1137/050639879), Gebremedhin et al. (2007), Algorithm 3.1
274276
"""
275277
function acyclic_coloring(
276-
g::AdjacencyGraph{T}, vertices_in_order::AbstractVector{<:Integer}, postprocessing::Bool
278+
g::AdjacencyGraph{T},
279+
vertices_in_order::AbstractVector{<:Integer},
280+
postprocessing::Bool;
281+
postprocessing_minimizes::Symbol=:all_colors,
277282
) where {T<:Integer}
278283
# Initialize data structures
279284
nv = nb_vertices(g)
@@ -345,7 +350,7 @@ function acyclic_coloring(
345350
if postprocessing
346351
# Reuse the vector forbidden_colors to compute offsets during post-processing
347352
offsets = forbidden_colors
348-
postprocess!(color, tree_set, g, offsets)
353+
postprocess!(color, tree_set, g, offsets, postprocessing_minimizes)
349354
end
350355
return color, tree_set
351356
end

src/graph.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ The adjacency graph of a symmetric matrix `A ∈ ℝ^{n × n}` is `G(A) = (V, E)
227227
struct AdjacencyGraph{T<:Integer,augmented_graph}
228228
S::SparsityPatternCSC{T}
229229
edge_to_index::Vector{T}
230+
original_size::Tuple{Int,Int}
230231
end
231232

232233
Base.eltype(::AdjacencyGraph{T}) where {T} = T
@@ -235,15 +236,20 @@ function AdjacencyGraph(
235236
S::SparsityPatternCSC{T},
236237
edge_to_index::Vector{T}=build_edge_to_index(S);
237238
augmented_graph::Bool=false,
239+
original_size::Tuple{Int,Int}=size(S),
238240
) where {T}
239-
return AdjacencyGraph{T,augmented_graph}(S, edge_to_index)
241+
return AdjacencyGraph{T,augmented_graph}(S, edge_to_index, original_size)
240242
end
241243

242-
function AdjacencyGraph(A::SparseMatrixCSC; augmented_graph::Bool=false)
244+
function AdjacencyGraph(
245+
A::SparseMatrixCSC; augmented_graph::Bool=false, original_size::Tuple{Int,Int}=size(A)
246+
)
243247
return AdjacencyGraph(SparsityPatternCSC(A); augmented_graph)
244248
end
245249

246-
function AdjacencyGraph(A::AbstractMatrix; augmented_graph::Bool=false)
250+
function AdjacencyGraph(
251+
A::AbstractMatrix; augmented_graph::Bool=false, original_size::Tuple{Int,Int}=size(A)
252+
)
247253
return AdjacencyGraph(SparseMatrixCSC(A); augmented_graph)
248254
end
249255

src/interface.jl

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,12 @@ It is passed as an argument to the main function [`coloring`](@ref).
6969
7070
# Constructors
7171
72-
GreedyColoringAlgorithm{decompression}(order=NaturalOrder(); postprocessing=false)
73-
GreedyColoringAlgorithm(order=NaturalOrder(); postprocessing=false, decompression=:direct)
72+
GreedyColoringAlgorithm{decompression}(order=NaturalOrder(); postprocessing=false, postprocessing_minimizes=:all_colors)
73+
GreedyColoringAlgorithm(order=NaturalOrder(); postprocessing=false, postprocessing_minimizes=:all_colors, decompression=:direct)
7474
7575
- `order::Union{AbstractOrder,Tuple}`: the order in which the columns or rows are colored, which can impact the number of colors. Can also be a tuple of different orders to try out, from which the best order (the one with the lowest total number of colors) will be used.
76-
- `postprocessing::Bool`: whether or not the coloring will be refined by assigning the neutral color `0` to some vertices.
76+
- `postprocessing::Bool`: whether or not the coloring will be refined by assigning the neutral color `0` to some vertices. This option does not affect row or column colorings.
77+
- `postprocessing_minimizes::Symbol`: which number of distinct colors is heuristically minimized by postprocessing, either `:all_colors`, `:row_colors` or `:column_colors`. This option only affects bidirectional colorings.
7778
- `decompression::Symbol`: either `:direct` or `:substitution`. Usually `:substitution` leads to fewer colors, at the cost of a more expensive coloring (and decompression). When `:substitution` is not applicable, it falls back on `:direct` decompression.
7879
7980
!!! warning
@@ -98,27 +99,34 @@ struct GreedyColoringAlgorithm{decompression,N,O<:NTuple{N,AbstractOrder}} <:
9899
ADTypes.AbstractColoringAlgorithm
99100
orders::O
100101
postprocessing::Bool
102+
postprocessing_minimizes::Symbol
101103

102104
function GreedyColoringAlgorithm{decompression}(
103105
order_or_orders::Union{AbstractOrder,Tuple}=NaturalOrder();
104106
postprocessing::Bool=false,
107+
postprocessing_minimizes::Symbol=:all_colors,
105108
) where {decompression}
106109
check_valid_algorithm(decompression)
107110
if order_or_orders isa AbstractOrder
108111
orders = (order_or_orders,)
109112
else
110113
orders = order_or_orders
111114
end
112-
return new{decompression,length(orders),typeof(orders)}(orders, postprocessing)
115+
return new{decompression,length(orders),typeof(orders)}(
116+
orders, postprocessing, postprocessing_minimizes
117+
)
113118
end
114119
end
115120

116121
function GreedyColoringAlgorithm(
117122
order_or_orders::Union{AbstractOrder,Tuple}=NaturalOrder();
118123
postprocessing::Bool=false,
119124
decompression::Symbol=:direct,
125+
postprocessing_minimizes::Symbol=:all_colors,
120126
)
121-
return GreedyColoringAlgorithm{decompression}(order_or_orders; postprocessing)
127+
return GreedyColoringAlgorithm{decompression}(
128+
order_or_orders; postprocessing, postprocessing_minimizes
129+
)
122130
end
123131

124132
## Coloring
@@ -279,7 +287,7 @@ function _coloring(
279287
symmetric_pattern::Bool;
280288
forced_colors::Union{AbstractVector{<:Integer},Nothing}=nothing,
281289
)
282-
ag = AdjacencyGraph(A; augmented_graph=false)
290+
ag = AdjacencyGraph(A; augmented_graph=false, original_size=size(A))
283291
color_and_star_set_by_order = map(algo.orders) do order
284292
vertices_in_order = vertices(ag, order)
285293
return star_coloring(ag, vertices_in_order, algo.postprocessing; forced_colors)
@@ -300,7 +308,7 @@ function _coloring(
300308
decompression_eltype::Type{R},
301309
symmetric_pattern::Bool,
302310
) where {R}
303-
ag = AdjacencyGraph(A; augmented_graph=false)
311+
ag = AdjacencyGraph(A; augmented_graph=false, original_size=size(A))
304312
color_and_tree_set_by_order = map(algo.orders) do order
305313
vertices_in_order = vertices(ag, order)
306314
return acyclic_coloring(ag, vertices_in_order, algo.postprocessing)
@@ -323,11 +331,18 @@ function _coloring(
323331
forced_colors::Union{AbstractVector{<:Integer},Nothing}=nothing,
324332
) where {R}
325333
A_and_Aᵀ, edge_to_index = bidirectional_pattern(A; symmetric_pattern)
326-
ag = AdjacencyGraph(A_and_Aᵀ, edge_to_index; augmented_graph=true)
334+
ag = AdjacencyGraph(
335+
A_and_Aᵀ, edge_to_index; augmented_graph=true, original_size=size(A)
336+
)
337+
postprocessing_minimizes = algo.postprocessing_minimizes
327338
outputs_by_order = map(algo.orders) do order
328339
vertices_in_order = vertices(ag, order)
329340
_color, _star_set = star_coloring(
330-
ag, vertices_in_order, algo.postprocessing; forced_colors
341+
ag,
342+
vertices_in_order,
343+
algo.postprocessing;
344+
postprocessing_minimizes,
345+
forced_colors,
331346
)
332347
(_row_color, _column_color, _symmetric_to_row, _symmetric_to_column) = remap_colors(
333348
eltype(ag), _color, maximum(_color), size(A)...
@@ -370,10 +385,15 @@ function _coloring(
370385
symmetric_pattern::Bool,
371386
) where {R}
372387
A_and_Aᵀ, edge_to_index = bidirectional_pattern(A; symmetric_pattern)
373-
ag = AdjacencyGraph(A_and_Aᵀ, edge_to_index; augmented_graph=true)
388+
ag = AdjacencyGraph(
389+
A_and_Aᵀ, edge_to_index; augmented_graph=true, original_size=size(A)
390+
)
391+
postprocessing_minimizes = algo.postprocessing_minimizes
374392
outputs_by_order = map(algo.orders) do order
375393
vertices_in_order = vertices(ag, order)
376-
_color, _tree_set = acyclic_coloring(ag, vertices_in_order, algo.postprocessing)
394+
_color, _tree_set = acyclic_coloring(
395+
ag, vertices_in_order, algo.postprocessing; postprocessing_minimizes
396+
)
377397
(_row_color, _column_color, _symmetric_to_row, _symmetric_to_column) = remap_colors(
378398
eltype(ag), _color, maximum(_color), size(A)...
379399
)

0 commit comments

Comments
 (0)