Skip to content

Commit a9460dc

Browse files
Add in and out edge indices functions (#1369)
* Add method to retrieve incoming edge indices for a node * Add type annotations in stub file * Add tests for in_edge_indices * Add method to get outgoing edge indices for a node * Add in_edge_indices and out_edge_indices methods in PyGraph * update stub file * Add tests for out_edge_indices and in_edge_indices * Add release notes for in_edge_indices and out_edge_indices functions * Change releasenote name
1 parent df38b0a commit a9460dc

File tree

6 files changed

+173
-0
lines changed

6 files changed

+173
-0
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
features:
2+
- |
3+
Added two new functions, :func:`~rustworkx.in_edge_indices` and
4+
:func:`~rustworkx.out_edge_indices`, to the ``rustworkx.PyDiGraph`` class. These functions return the indices of all incoming and outgoing edges for a given node in a directed graph, respectively.
5+
- |
6+
Added two new functions, :func:`~rustworkx.in_edge_indices` and
7+
:func:`~rustworkx.out_edge_indices`, to the ``rustworkx.PyGraph`` class. As ``PyGraph`` is an undirected graph, both functions return the indices of all edges connected to the given node. This maintains API consistency with the directed graph implementation.

rustworkx/rustworkx.pyi

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1262,13 +1262,15 @@ class PyGraph(Generic[_S, _T]):
12621262
def in_edges(self, node: int, /) -> WeightedEdgeList[_T]: ...
12631263
def incident_edge_index_map(self, node: int, /) -> EdgeIndexMap: ...
12641264
def incident_edges(self, node: int, /) -> EdgeIndices: ...
1265+
def in_edge_indices(self, node: int, /) -> EdgeIndices: ...
12651266
def neighbors(self, node: int, /) -> NodeIndices: ...
12661267
def node_indexes(self) -> NodeIndices: ...
12671268
def node_indices(self) -> NodeIndices: ...
12681269
def nodes(self) -> list[_S]: ...
12691270
def num_edges(self) -> int: ...
12701271
def num_nodes(self) -> int: ...
12711272
def out_edges(self, node: int, /) -> WeightedEdgeList[_T]: ...
1273+
def out_edge_indices(self, node: int, /) -> EdgeIndices: ...
12721274
@staticmethod
12731275
def read_edge_list(
12741276
path: str,
@@ -1428,6 +1430,7 @@ class PyDiGraph(Generic[_S, _T]):
14281430
def in_edges(self, node: int, /) -> WeightedEdgeList[_T]: ...
14291431
def incident_edge_index_map(self, node: int, /, all_edges: bool = ...) -> EdgeIndexMap: ...
14301432
def incident_edges(self, node: int, /, all_edges: bool = ...) -> EdgeIndices: ...
1433+
def in_edge_indices(self, node: int, /) -> EdgeIndices: ...
14311434
def insert_node_on_in_edges(self, node: int, ref_node: int, /) -> None: ...
14321435
def insert_node_on_in_edges_multiple(self, node: int, ref_nodes: Sequence[int], /) -> None: ...
14331436
def insert_node_on_out_edges(self, node: int, ref_node: int, /) -> None: ...
@@ -1444,6 +1447,7 @@ class PyDiGraph(Generic[_S, _T]):
14441447
def num_nodes(self) -> int: ...
14451448
def out_degree(self, node: int, /) -> int: ...
14461449
def out_edges(self, node: int, /) -> WeightedEdgeList[_T]: ...
1450+
def out_edge_indices(self, node: int, /) -> EdgeIndices: ...
14471451
def predecessor_indices(self, node: int, /) -> NodeIndices: ...
14481452
def predecessors(self, node: int, /) -> list[_S]: ...
14491453
@staticmethod

src/digraph.rs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1842,6 +1842,54 @@ impl PyDiGraph {
18421842
}
18431843
}
18441844

1845+
/// Return the list of incoming edge indices to a provided node
1846+
///
1847+
/// This method will return the incoming edges of the provided
1848+
/// ``node``.
1849+
///
1850+
/// :param int node: The node index to get incoming edges from. If
1851+
/// this node index is not present in the graph this method will
1852+
/// return an empty list and not error.
1853+
///
1854+
/// :returns: A list of the incoming edge indices to a node in the graph
1855+
/// :rtype: EdgeIndices
1856+
#[pyo3(text_signature = "(self, node, /)")]
1857+
pub fn in_edge_indices(&self, node: usize) -> EdgeIndices {
1858+
let node_index = NodeIndex::new(node);
1859+
let dir = petgraph::Direction::Incoming;
1860+
EdgeIndices {
1861+
edges: self
1862+
.graph
1863+
.edges_directed(node_index, dir)
1864+
.map(|e| e.id().index())
1865+
.collect(),
1866+
}
1867+
}
1868+
1869+
/// Return the list of outgoing edge indices from a provided node
1870+
///
1871+
/// This method will return the outgoing edges of the provided
1872+
/// ``node``.
1873+
///
1874+
/// :param int node: The node index to get outgoing edges from. If
1875+
/// this node index is not present in the graph this method will
1876+
/// return an empty list and not error.
1877+
///
1878+
/// :returns: A list of the outgoing edge indices from a node in the graph
1879+
/// :rtype: EdgeIndices
1880+
#[pyo3(text_signature = "(self, node, /)")]
1881+
pub fn out_edge_indices(&self, node: usize) -> EdgeIndices {
1882+
let node_index = NodeIndex::new(node);
1883+
let dir = petgraph::Direction::Outgoing;
1884+
EdgeIndices {
1885+
edges: self
1886+
.graph
1887+
.edges_directed(node_index, dir)
1888+
.map(|e| e.id().index())
1889+
.collect(),
1890+
}
1891+
}
1892+
18451893
/// Return the index map of edges incident to a provided node
18461894
///
18471895
/// By default this method will only return the outgoing edges of

src/graph.rs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,52 @@ impl PyGraph {
552552
}
553553
}
554554

555+
/// Return the list of edge indices incident to a provided node.
556+
///
557+
/// This method returns the indices of all edges connected to the provided
558+
/// ``node``. In undirected graphs, all edges connected to the node are
559+
/// returned as there is no distinction between incoming and outgoing edges.
560+
///
561+
/// :param int node: The node index to get incident edges from. If
562+
/// this node index is not present in the graph this method will
563+
/// return an empty list and not error.
564+
///
565+
/// :returns: A list of the edge indices incident to the node
566+
/// :rtype: EdgeIndices
567+
#[pyo3(text_signature = "(self, node, /)")]
568+
pub fn in_edge_indices(&self, node: usize) -> EdgeIndices {
569+
EdgeIndices {
570+
edges: self
571+
.graph
572+
.edges(NodeIndex::new(node))
573+
.map(|e| e.id().index())
574+
.collect(),
575+
}
576+
}
577+
578+
/// Return the list of edge indices incident to a provided node.
579+
///
580+
/// This method returns the indices of all edges connected to the provided
581+
/// ``node``. In undirected graphs, all edges connected to the node are
582+
/// returned as there is no distinction between incoming and outgoing edges.
583+
///
584+
/// :param int node: The node index to get incident edges from. If
585+
/// this node index is not present in the graph this method will
586+
/// return an empty list and not error.
587+
///
588+
/// :returns: A list of the edge indices incident to the node
589+
/// :rtype: EdgeIndices
590+
#[pyo3(text_signature = "(self, node, /)")]
591+
pub fn out_edge_indices(&self, node: usize) -> EdgeIndices {
592+
EdgeIndices {
593+
edges: self
594+
.graph
595+
.edges(NodeIndex::new(node))
596+
.map(|e| e.id().index())
597+
.collect(),
598+
}
599+
}
600+
555601
/// Return the index map of edges incident to a provided node
556602
///
557603
/// :param int node: The node index to get incident edges from. If

tests/digraph/test_edges.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -813,6 +813,40 @@ def test_incident_edges_all_edges(self):
813813
res = graph.incident_edges(node_d, all_edges=True)
814814
self.assertEqual([2, 1], res)
815815

816+
def test_in_edge_indices(self):
817+
graph = rustworkx.PyDiGraph()
818+
node_a = graph.add_node(0)
819+
node_b = graph.add_node(1)
820+
node_c = graph.add_node("c")
821+
node_d = graph.add_node("d")
822+
edge_ac = graph.add_edge(node_a, node_c, "edge a")
823+
graph.add_edge(node_b, node_d, "edge b")
824+
edge_dc = graph.add_edge(node_d, node_c, "edge c")
825+
res = graph.in_edge_indices(node_c)
826+
self.assertEqual([edge_dc, edge_ac], res)
827+
828+
def test_in_edge_indices_invalid_node(self):
829+
graph = rustworkx.PyDiGraph()
830+
res = graph.in_edge_indices(0)
831+
self.assertEqual([], res)
832+
833+
def test_out_edge_indices(self):
834+
graph = rustworkx.PyDiGraph()
835+
node_a = graph.add_node(0)
836+
node_b = graph.add_node(1)
837+
node_c = graph.add_node("c")
838+
node_d = graph.add_node("d")
839+
edge_ab = graph.add_edge(node_a, node_b, "edge a")
840+
edge_ac = graph.add_edge(node_a, node_c, "edge b")
841+
graph.add_edge(node_c, node_d, "edge c")
842+
res = graph.out_edge_indices(node_a)
843+
self.assertEqual([edge_ac, edge_ab], res)
844+
845+
def test_out_edge_indices_invalid_node(self):
846+
graph = rustworkx.PyDiGraph()
847+
res = graph.out_edge_indices(0)
848+
self.assertEqual([], res)
849+
816850
def test_incident_edge_index_map(self):
817851
graph = rustworkx.PyDiGraph()
818852
node_a = graph.add_node(0)

tests/graph/test_edges.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,40 @@ def test_incident_edges_invalid_node(self):
526526
res = graph.incident_edges(42)
527527
self.assertEqual([], res)
528528

529+
def test_in_edge_indices(self):
530+
graph = rustworkx.PyGraph()
531+
node_a = graph.add_node(0)
532+
node_b = graph.add_node(1)
533+
node_c = graph.add_node("c")
534+
node_d = graph.add_node("d")
535+
graph.add_edge(node_a, node_c, "edge a")
536+
graph.add_edge(node_b, node_d, "edge_b")
537+
graph.add_edge(node_c, node_d, "edge c")
538+
res = graph.in_edge_indices(node_d)
539+
self.assertEqual({1, 2}, set(res))
540+
541+
def test_in_edge_indices_invalid_node(self):
542+
graph = rustworkx.PyGraph()
543+
res = graph.in_edge_indices(0)
544+
self.assertEqual([], res)
545+
546+
def test_out_edge_indices(self):
547+
graph = rustworkx.PyGraph()
548+
node_a = graph.add_node(0)
549+
node_b = graph.add_node(1)
550+
node_c = graph.add_node("c")
551+
node_d = graph.add_node("d")
552+
graph.add_edge(node_a, node_c, "edge a")
553+
graph.add_edge(node_b, node_d, "edge_b")
554+
graph.add_edge(node_c, node_d, "edge c")
555+
res = graph.out_edge_indices(node_d)
556+
self.assertEqual({1, 2}, set(res))
557+
558+
def test_out_edge_indices_invalid_node(self):
559+
graph = rustworkx.PyGraph()
560+
res = graph.out_edge_indices(0)
561+
self.assertEqual([], res)
562+
529563
def test_incident_edge_index_map(self):
530564
graph = rustworkx.PyGraph()
531565
node_a = graph.add_node(0)

0 commit comments

Comments
 (0)