Skip to content

Commit e39ecc6

Browse files
authored
Add option to specify preset colors to graph_greedy_color (#1060)
* Add option to specify preset colors to graph_greedy_color This commit adds a new argument to the graph_greed_color function, which enables a user to provide a callback function that specifies a color to use for particular nodes. * Fix type hint stub
1 parent ddb0cda commit e39ecc6

File tree

5 files changed

+207
-35
lines changed

5 files changed

+207
-35
lines changed
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
---
2+
features:
3+
- |
4+
Added a new keyword argument, ``preset_color_fn``, to :func:`.graph_greedy_color`
5+
which is used to provide preset colors for specific nodes when computing the graph
6+
coloring. You can optionally pass a callable to that argument which will
7+
be passed node index from the graph and is either expected to return an
8+
integer color to use for that node, or `None` to indicate there is no
9+
preset color for that node. For example:
10+
11+
.. jupyter-execute::
12+
13+
import rustworkx as rx
14+
from rustworkx.visualization import mpl_draw
15+
16+
graph = rx.generators.generalized_petersen_graph(5, 2)
17+
18+
def preset_colors(node_index):
19+
if node_index == 0:
20+
return 3
21+
22+
coloring = rx.graph_greedy_color(graph, preset_color_fn=preset_colors)
23+
colors = [coloring[node] for node in graph.node_indices()]
24+
25+
layout = rx.shell_layout(graph, nlist=[[0, 1, 2, 3, 4],[6, 7, 8, 9, 5]])
26+
mpl_draw(graph, node_color=colors, pos=layout)
27+
- |
28+
Added a new function ``greedy_node_color_with_preset_colors`` to the
29+
rustworkx-core module ``coloring``. This new function is identical to the
30+
``rustworkx_core::coloring::greedy_node_color`` except it has a second
31+
preset parameter which is passed a callable which is used to provide preset
32+
colors for particular node ids.

rustworkx-core/src/coloring.rs

Lines changed: 103 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
// under the License.
1212

1313
use std::cmp::Reverse;
14+
use std::convert::Infallible;
1415
use std::hash::Hash;
1516

1617
use crate::dictmap::*;
@@ -96,6 +97,50 @@ where
9697
Some(colors)
9798
}
9899

100+
fn inner_greedy_node_color<G, F, E>(
101+
graph: G,
102+
mut preset_color_fn: F,
103+
) -> Result<DictMap<G::NodeId, usize>, E>
104+
where
105+
G: NodeCount + IntoNodeIdentifiers + IntoEdges,
106+
G::NodeId: Hash + Eq + Send + Sync,
107+
F: FnMut(G::NodeId) -> Result<Option<usize>, E>,
108+
{
109+
let mut colors: DictMap<G::NodeId, usize> = DictMap::with_capacity(graph.node_count());
110+
let mut node_vec: Vec<G::NodeId> = Vec::with_capacity(graph.node_count());
111+
let mut sort_map: HashMap<G::NodeId, usize> = HashMap::with_capacity(graph.node_count());
112+
for k in graph.node_identifiers() {
113+
if let Some(color) = preset_color_fn(k)? {
114+
colors.insert(k, color);
115+
continue;
116+
}
117+
node_vec.push(k);
118+
sort_map.insert(k, graph.edges(k).count());
119+
}
120+
node_vec.par_sort_by_key(|k| Reverse(sort_map.get(k)));
121+
122+
for node in node_vec {
123+
let mut neighbor_colors: HashSet<usize> = HashSet::new();
124+
for edge in graph.edges(node) {
125+
let target = edge.target();
126+
let existing_color = match colors.get(&target) {
127+
Some(color) => color,
128+
None => continue,
129+
};
130+
neighbor_colors.insert(*existing_color);
131+
}
132+
let mut current_color: usize = 0;
133+
loop {
134+
if !neighbor_colors.contains(&current_color) {
135+
break;
136+
}
137+
current_color += 1;
138+
}
139+
colors.insert(node, current_color);
140+
}
141+
Ok(colors)
142+
}
143+
99144
/// Color a graph using a greedy graph coloring algorithm.
100145
///
101146
/// This function uses a `largest-first` strategy as described in:
@@ -135,36 +180,65 @@ where
135180
G: NodeCount + IntoNodeIdentifiers + IntoEdges,
136181
G::NodeId: Hash + Eq + Send + Sync,
137182
{
138-
let mut colors: DictMap<G::NodeId, usize> = DictMap::with_capacity(graph.node_count());
139-
let mut node_vec: Vec<G::NodeId> = graph.node_identifiers().collect();
140-
141-
let mut sort_map: HashMap<G::NodeId, usize> = HashMap::with_capacity(graph.node_count());
142-
for k in node_vec.iter() {
143-
sort_map.insert(*k, graph.edges(*k).count());
144-
}
145-
node_vec.par_sort_by_key(|k| Reverse(sort_map.get(k)));
146-
147-
for node in node_vec {
148-
let mut neighbor_colors: HashSet<usize> = HashSet::new();
149-
for edge in graph.edges(node) {
150-
let target = edge.target();
151-
let existing_color = match colors.get(&target) {
152-
Some(color) => color,
153-
None => continue,
154-
};
155-
neighbor_colors.insert(*existing_color);
156-
}
157-
let mut current_color: usize = 0;
158-
loop {
159-
if !neighbor_colors.contains(&current_color) {
160-
break;
161-
}
162-
current_color += 1;
163-
}
164-
colors.insert(node, current_color);
165-
}
183+
inner_greedy_node_color(graph, |_| Ok::<Option<usize>, Infallible>(None)).unwrap()
184+
}
166185

167-
colors
186+
/// Color a graph using a greedy graph coloring algorithm with preset colors
187+
///
188+
/// This function uses a `largest-first` strategy as described in:
189+
///
190+
/// Adrian Kosowski, and Krzysztof Manuszewski, Classical Coloring of Graphs,
191+
/// Graph Colorings, 2-19, 2004. ISBN 0-8218-3458-4.
192+
///
193+
/// to color the nodes with higher degree first.
194+
///
195+
/// The coloring problem is NP-hard and this is a heuristic algorithm
196+
/// which may not return an optimal solution.
197+
///
198+
/// Arguments:
199+
///
200+
/// * `graph` - The graph object to run the algorithm on
201+
/// * `preset_color_fn` - A callback function that will recieve the node identifier
202+
/// for each node in the graph and is expected to return an `Option<usize>`
203+
/// (wrapped in a `Result`) that is `None` if the node has no preset and
204+
/// the usize represents the preset color.
205+
///
206+
/// # Example
207+
/// ```rust
208+
///
209+
/// use petgraph::graph::Graph;
210+
/// use petgraph::graph::NodeIndex;
211+
/// use petgraph::Undirected;
212+
/// use rustworkx_core::dictmap::*;
213+
/// use std::convert::Infallible;
214+
/// use rustworkx_core::coloring::greedy_node_color_with_preset_colors;
215+
///
216+
/// let preset_color_fn = |node_idx: NodeIndex| -> Result<Option<usize>, Infallible> {
217+
/// if node_idx.index() == 0 {
218+
/// Ok(Some(1))
219+
/// } else {
220+
/// Ok(None)
221+
/// }
222+
/// };
223+
///
224+
/// let g = Graph::<(), (), Undirected>::from_edges(&[(0, 1), (0, 2)]);
225+
/// let colors = greedy_node_color_with_preset_colors(&g, preset_color_fn).unwrap();
226+
/// let mut expected_colors = DictMap::new();
227+
/// expected_colors.insert(NodeIndex::new(0), 1);
228+
/// expected_colors.insert(NodeIndex::new(1), 0);
229+
/// expected_colors.insert(NodeIndex::new(2), 0);
230+
/// assert_eq!(colors, expected_colors);
231+
/// ```
232+
pub fn greedy_node_color_with_preset_colors<G, F, E>(
233+
graph: G,
234+
preset_color_fn: F,
235+
) -> Result<DictMap<G::NodeId, usize>, E>
236+
where
237+
G: NodeCount + IntoNodeIdentifiers + IntoEdges,
238+
G::NodeId: Hash + Eq + Send + Sync,
239+
F: FnMut(G::NodeId) -> Result<Option<usize>, E>,
240+
{
241+
inner_greedy_node_color(graph, preset_color_fn)
168242
}
169243

170244
/// Color edges of a graph using a greedy approach.

rustworkx/rustworkx.pyi

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,9 @@ def graph_katz_centrality(
135135

136136
# Coloring
137137

138-
def graph_greedy_color(graph: PyGraph, /) -> dict[int, int]: ...
138+
def graph_greedy_color(
139+
graph: PyGraph, /, preset_color_fn: Callable[[int], int | None] | None = ...
140+
) -> dict[int, int]: ...
139141
def graph_greedy_edge_color(graph: PyGraph, /) -> dict[int, int]: ...
140142
def graph_is_bipartite(graph: PyGraph) -> bool: ...
141143
def digraph_is_bipartite(graph: PyDiGraph) -> bool: ...

src/coloring.rs

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,14 @@
1010
// License for the specific language governing permissions and limitations
1111
// under the License.
1212

13-
use crate::{digraph, graph};
13+
use crate::{digraph, graph, NodeIndex};
1414

1515
use pyo3::prelude::*;
1616
use pyo3::types::PyDict;
1717
use pyo3::Python;
1818
use rustworkx_core::coloring::{
19-
greedy_edge_color, greedy_node_color, misra_gries_edge_color, two_color,
19+
greedy_edge_color, greedy_node_color, greedy_node_color_with_preset_colors,
20+
misra_gries_edge_color, two_color,
2021
};
2122

2223
/// Color a :class:`~.PyGraph` object using a greedy graph coloring algorithm.
@@ -30,6 +31,13 @@ use rustworkx_core::coloring::{
3031
/// may not return an optimal solution.
3132
///
3233
/// :param PyGraph: The input PyGraph object to color
34+
/// :param preset_color_fn: An optional callback function that is used to manually
35+
/// specify a color to use for particular nodes in the graph. If specified
36+
/// this takes a callable that will be passed a node index and is expected to
37+
/// either return an integer representing a color or ``None`` to indicate there
38+
/// is no preset. Note if you do use a callable there is no validation that
39+
/// the preset values are valid colors. You can generate an invalid coloring
40+
/// if you the specified function returned invalid colors for any nodes.
3341
///
3442
/// :returns: A dictionary where keys are node indices and the value is
3543
/// the color
@@ -52,9 +60,23 @@ use rustworkx_core::coloring::{
5260
/// .. [1] Adrian Kosowski, and Krzysztof Manuszewski, Classical Coloring of Graphs,
5361
/// Graph Colorings, 2-19, 2004. ISBN 0-8218-3458-4.
5462
#[pyfunction]
55-
#[pyo3(text_signature = "(graph, /)")]
56-
pub fn graph_greedy_color(py: Python, graph: &graph::PyGraph) -> PyResult<PyObject> {
57-
let colors = greedy_node_color(&graph.graph);
63+
#[pyo3(text_signature = "(graph, /, preset_color_fn=None)")]
64+
pub fn graph_greedy_color(
65+
py: Python,
66+
graph: &graph::PyGraph,
67+
preset_color_fn: Option<PyObject>,
68+
) -> PyResult<PyObject> {
69+
let colors = match preset_color_fn {
70+
Some(preset_color_fn) => {
71+
let callback = |node_idx: NodeIndex| -> PyResult<Option<usize>> {
72+
preset_color_fn
73+
.call1(py, (node_idx.index(),))
74+
.map(|x| x.extract(py).ok())
75+
};
76+
greedy_node_color_with_preset_colors(&graph.graph, callback)?
77+
}
78+
None => greedy_node_color(&graph.graph),
79+
};
5880
let out_dict = PyDict::new(py);
5981
for (node, color) in colors {
6082
out_dict.set_item(node.index(), color)?;

tests/graph/test_coloring.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,48 @@ def test_simple_graph_large_degree(self):
4545
res = rustworkx.graph_greedy_color(graph)
4646
self.assertEqual({0: 0, 1: 1, 2: 1}, res)
4747

48+
def test_simple_graph_with_preset(self):
49+
def preset(node_idx):
50+
if node_idx == 0:
51+
return 1
52+
return None
53+
54+
graph = rustworkx.PyGraph()
55+
node_a = graph.add_node("a")
56+
node_b = graph.add_node("b")
57+
graph.add_edge(node_a, node_b, 1)
58+
node_c = graph.add_node("c")
59+
graph.add_edge(node_a, node_c, 1)
60+
res = rustworkx.graph_greedy_color(graph, preset)
61+
self.assertEqual({0: 1, 1: 0, 2: 0}, res)
62+
63+
def test_simple_graph_large_degree_with_preset(self):
64+
def preset(node_idx):
65+
if node_idx == 0:
66+
return 1
67+
return None
68+
69+
graph = rustworkx.PyGraph()
70+
node_a = graph.add_node("a")
71+
node_b = graph.add_node("b")
72+
graph.add_edge(node_a, node_b, 1)
73+
node_c = graph.add_node("c")
74+
graph.add_edge(node_a, node_c, 1)
75+
graph.add_edge(node_a, node_c, 1)
76+
graph.add_edge(node_a, node_c, 1)
77+
graph.add_edge(node_a, node_c, 1)
78+
graph.add_edge(node_a, node_c, 1)
79+
res = rustworkx.graph_greedy_color(graph, preset)
80+
self.assertEqual({0: 1, 1: 0, 2: 0}, res)
81+
82+
def test_preset_raises_exception(self):
83+
def preset(node_idx):
84+
raise OverflowError("I am invalid")
85+
86+
graph = rustworkx.generators.path_graph(5)
87+
with self.assertRaises(OverflowError):
88+
rustworkx.graph_greedy_color(graph, preset)
89+
4890

4991
class TestGraphEdgeColoring(unittest.TestCase):
5092
def test_graph(self):

0 commit comments

Comments
 (0)