|
4 | 4 | Coloring algorithm which always returns the same precomputed vector of colors. |
5 | 5 | Useful when the optimal coloring of a matrix can be determined a priori due to its specific structure (e.g. banded). |
6 | 6 |
|
7 | | -It is passed as an argument to the main function [`coloring`](@ref), but will only work if the associated `problem` has `:nonsymmetric` structure. |
8 | | -Indeed, for symmetric coloring problems, we need more than just the vector of colors to allow fast decompression. |
| 7 | +It is passed as an argument to the main function [`coloring`](@ref), but will only work if the associated `problem` has a `:column` or `:row` partition. |
9 | 8 |
|
10 | 9 | # Constructors |
11 | 10 |
|
12 | 11 | ConstantColoringAlgorithm{partition}(matrix_template, color) |
13 | | - ConstantColoringAlgorithm(matrix_template, color; partition=:column) |
| 12 | + ConstantColoringAlgorithm{partition,structure}(matrix_template, color) |
| 13 | + ConstantColoringAlgorithm( |
| 14 | + matrix_template, color; |
| 15 | + structure=:nonsymmetric, partition=:column |
| 16 | + ) |
14 | 17 |
|
15 | 18 | - `partition::Symbol`: either `:row` or `:column`. |
| 19 | +- `structure::Symbol`: either `:nonsymmetric` or `:symmetric`. |
16 | 20 | - `matrix_template::AbstractMatrix`: matrix for which the vector of colors was precomputed (the algorithm will only accept matrices of the exact same size). |
17 | 21 | - `color::Vector{<:Integer}`: vector of integer colors, one for each row or column (depending on `partition`). |
18 | 22 |
|
19 | 23 | !!! warning |
20 | | - The second constructor (based on keyword arguments) is type-unstable. |
| 24 | + The constructor based on keyword arguments is type-unstable if these arguments are not compile-time constants. |
21 | 25 |
|
22 | 26 | We do not necessarily verify consistency between the matrix template and the vector of colors, this is the responsibility of the user. |
23 | 27 |
|
@@ -63,71 +67,68 @@ julia> column_colors(result) |
63 | 67 |
|
64 | 68 | - [`ADTypes.column_coloring`](@extref ADTypes.column_coloring) |
65 | 69 | - [`ADTypes.row_coloring`](@extref ADTypes.row_coloring) |
| 70 | +- [`ADTypes.symmetric_coloring`](@extref ADTypes.symmetric_coloring) |
66 | 71 | """ |
67 | | -struct ConstantColoringAlgorithm{ |
68 | | - partition, |
69 | | - M<:AbstractMatrix, |
70 | | - T<:Integer, |
71 | | - R<:AbstractColoringResult{:nonsymmetric,partition,:direct}, |
72 | | -} <: ADTypes.AbstractColoringAlgorithm |
| 72 | +struct ConstantColoringAlgorithm{partition,structure,M<:AbstractMatrix,T<:Integer} <: |
| 73 | + ADTypes.AbstractColoringAlgorithm |
73 | 74 | matrix_template::M |
74 | 75 | color::Vector{T} |
75 | | - result::R |
76 | | -end |
77 | 76 |
|
78 | | -function ConstantColoringAlgorithm{:column}( |
79 | | - matrix_template::AbstractMatrix, color::Vector{<:Integer} |
80 | | -) |
81 | | - bg = BipartiteGraph(matrix_template) |
82 | | - result = ColumnColoringResult(matrix_template, bg, color) |
83 | | - T, M, R = eltype(bg), typeof(matrix_template), typeof(result) |
84 | | - return ConstantColoringAlgorithm{:column,M,T,R}(matrix_template, color, result) |
| 77 | + function ConstantColoringAlgorithm{partition,structure}( |
| 78 | + matrix_template::AbstractMatrix, color::Vector{<:Integer} |
| 79 | + ) where {partition,structure} |
| 80 | + check_valid_problem(structure, partition) |
| 81 | + return new{partition,structure,typeof(matrix_template),eltype(color)}( |
| 82 | + matrix_template, color |
| 83 | + ) |
| 84 | + end |
85 | 85 | end |
86 | 86 |
|
87 | | -function ConstantColoringAlgorithm{:row}( |
| 87 | +function ConstantColoringAlgorithm{partition}( |
88 | 88 | matrix_template::AbstractMatrix, color::Vector{<:Integer} |
89 | | -) |
90 | | - bg = BipartiteGraph(matrix_template) |
91 | | - result = RowColoringResult(matrix_template, bg, color) |
92 | | - T, M, R = eltype(bg), typeof(matrix_template), typeof(result) |
93 | | - return ConstantColoringAlgorithm{:row,M,T,R}(matrix_template, color, result) |
| 89 | +) where {partition} |
| 90 | + return ConstantColoringAlgorithm{partition,:nonsymmetric}(matrix_template, color) |
94 | 91 | end |
95 | 92 |
|
96 | 93 | function ConstantColoringAlgorithm( |
97 | | - matrix_template::AbstractMatrix, color::Vector{<:Integer}; partition::Symbol=:column |
| 94 | + matrix_template::AbstractMatrix, |
| 95 | + color::Vector{<:Integer}; |
| 96 | + structure::Symbol=:nonsymmetric, |
| 97 | + partition::Symbol=:column, |
98 | 98 | ) |
99 | | - return ConstantColoringAlgorithm{partition}(matrix_template, color) |
| 99 | + return ConstantColoringAlgorithm{partition,structure}(matrix_template, color) |
100 | 100 | end |
101 | 101 |
|
102 | | -function coloring( |
103 | | - A::AbstractMatrix, |
104 | | - ::ColoringProblem{:nonsymmetric,partition}, |
105 | | - algo::ConstantColoringAlgorithm{partition}; |
106 | | - decompression_eltype::Type=Float64, |
107 | | - symmetric_pattern::Bool=false, |
108 | | -) where {partition} |
109 | | - (; matrix_template, result) = algo |
| 102 | +function check_template(algo::ConstantColoringAlgorithm, A::AbstractMatrix) |
| 103 | + (; matrix_template) = algo |
110 | 104 | if size(A) != size(matrix_template) |
111 | 105 | throw( |
112 | 106 | DimensionMismatch( |
113 | 107 | "`ConstantColoringAlgorithm` expected matrix of size $(size(matrix_template)) but got matrix of size $(size(A))", |
114 | 108 | ), |
115 | 109 | ) |
116 | | - else |
117 | | - return result |
118 | 110 | end |
119 | 111 | end |
120 | 112 |
|
121 | 113 | function ADTypes.column_coloring( |
122 | | - A::AbstractMatrix, algo::ConstantColoringAlgorithm{:column} |
| 114 | + A::AbstractMatrix, algo::ConstantColoringAlgorithm{:column,:nonsymmetric} |
| 115 | +) |
| 116 | + check_template(algo, A) |
| 117 | + return algo.color |
| 118 | +end |
| 119 | + |
| 120 | +function ADTypes.row_coloring( |
| 121 | + A::AbstractMatrix, algo::ConstantColoringAlgorithm{:row,:nonsymmetric} |
123 | 122 | ) |
124 | | - problem = ColoringProblem{:nonsymmetric,:column}() |
125 | | - result = coloring(A, problem, algo) |
126 | | - return column_colors(result) |
| 123 | + check_template(algo, A) |
| 124 | + return algo.color |
127 | 125 | end |
128 | 126 |
|
129 | | -function ADTypes.row_coloring(A::AbstractMatrix, algo::ConstantColoringAlgorithm) |
130 | | - problem = ColoringProblem{:nonsymmetric,:row}() |
131 | | - result = coloring(A, problem, algo) |
132 | | - return row_colors(result) |
| 127 | +function ADTypes.symmetric_coloring( |
| 128 | + A::AbstractMatrix, algo::ConstantColoringAlgorithm{:column,:symmetric} |
| 129 | +) |
| 130 | + check_template(algo, A) |
| 131 | + return algo.color |
133 | 132 | end |
| 133 | + |
| 134 | +# TODO: handle bidirectional once https://github.com/SciML/ADTypes.jl/issues/69 is done |
0 commit comments