Skip to content

Commit 8b5d38b

Browse files
Add PyDiGraph.neighbors_undirected (#1254)
* Add `PyDiGraph.neighbors_undirected` * Add reno * add stub * review comments - additional test comparing w/ to_undirected - example in docstring * Apply suggestions from code review --------- Co-authored-by: Ivan Carvalho <[email protected]>
1 parent b6f0ff5 commit 8b5d38b

File tree

4 files changed

+61
-0
lines changed

4 files changed

+61
-0
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
---
2+
features:
3+
- |
4+
Added a new method :meth:`~rustworkx.PyDiGraph.neighbors_undirected` to
5+
obtain the neighbors of a node in a directed graph, irrespective of the
6+
edge directionality.

rustworkx/rustworkx.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1391,6 +1391,7 @@ class PyDiGraph(Generic[_S, _T]):
13911391
def make_symmetric(self, edge_payload_fn: Callable[[_T], _T] | None = ...) -> None: ...
13921392
def merge_nodes(self, u: int, v: int, /) -> None: ...
13931393
def neighbors(self, node: int, /) -> NodeIndices: ...
1394+
def neighbors_undirected(self, node: int, /) -> NodeIndices: ...
13941395
def node_indexes(self) -> NodeIndices: ...
13951396
def node_indices(self) -> NodeIndices: ...
13961397
def nodes(self) -> list[_S]: ...

src/digraph.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1720,6 +1720,38 @@ impl PyDiGraph {
17201720
}
17211721
}
17221722

1723+
/// Get the direction-agnostic neighbors (i.e. successors and predecessors) of a node.
1724+
///
1725+
/// This is functionally equivalent to converting the directed graph to an undirected
1726+
/// graph, and calling ``neighbors`` thereon. For example::
1727+
///
1728+
/// import rustworkx
1729+
///
1730+
/// dag = rustworkx.generators.directed_cycle_graph(num_nodes=10, bidirectional=False)
1731+
///
1732+
/// node = 3
1733+
/// neighbors = dag.neighbors_undirected(node)
1734+
/// same_neighbors = dag.to_undirected().neighbors(node)
1735+
///
1736+
/// assert sorted(neighbors) == sorted(same_neighbors)
1737+
///
1738+
/// :param int node: The index of the node to get the neighbors of
1739+
///
1740+
/// :returns: A list of the neighbor node indices
1741+
/// :rtype: NodeIndices
1742+
#[pyo3(text_signature = "(self, node, /)")]
1743+
pub fn neighbors_undirected(&self, node: usize) -> NodeIndices {
1744+
NodeIndices {
1745+
nodes: self
1746+
.graph
1747+
.neighbors_undirected(NodeIndex::new(node))
1748+
.map(|node| node.index())
1749+
.collect::<HashSet<usize>>()
1750+
.drain()
1751+
.collect(),
1752+
}
1753+
}
1754+
17231755
/// Get the successor indices of a node.
17241756
///
17251757
/// This will return a list of the node indicies for the succesors of

tests/digraph/test_neighbors.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import unittest
1414

1515
import rustworkx
16+
import rustworkx.generators
1617

1718

1819
class TestAdj(unittest.TestCase):
@@ -57,3 +58,24 @@ def test_no_neighbor(self):
5758
dag = rustworkx.PyDAG()
5859
node_a = dag.add_node("a")
5960
self.assertEqual([], dag.neighbors(node_a))
61+
62+
def test_undirected_neighbors(self):
63+
dag = rustworkx.PyDAG()
64+
node_a = dag.add_node("a")
65+
node_b = dag.add_child(node_a, "b", {"a": 1})
66+
67+
directed = dag.neighbors(node_b)
68+
self.assertEqual([], directed)
69+
70+
undirected = dag.neighbors_undirected(node_b)
71+
self.assertEqual([node_a], undirected)
72+
73+
def test_undirected_neighbors_cycle(self):
74+
num_nodes = 10
75+
dag = rustworkx.generators.directed_cycle_graph(num_nodes, bidirectional=False)
76+
undirected_dag = dag.to_undirected()
77+
78+
for node in dag.node_indices():
79+
undirected_neighbors = dag.neighbors_undirected(node)
80+
expected_neighbors = undirected_dag.neighbors(node)
81+
self.assertEqual(sorted(undirected_neighbors), sorted(expected_neighbors))

0 commit comments

Comments
 (0)