Skip to content

Commit d61f85d

Browse files
Feature: add .all_branches property to graph (#16)
* Add .all_branches property to BaseGraphModel Signed-off-by: Thijs Baaijen <[email protected]> * Apply suggestions from code review Co-authored-by: Vincent Koppen <[email protected]> Signed-off-by: Thijs Baaijen <[email protected]> --------- Signed-off-by: Thijs Baaijen <[email protected]> Co-authored-by: Vincent Koppen <[email protected]>
1 parent d28aba7 commit d61f85d

File tree

4 files changed

+35
-1
lines changed

4 files changed

+35
-1
lines changed

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
1.0
1+
1.1

src/power_grid_model_ds/_core/model/graphs/models/base.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# SPDX-License-Identifier: MPL-2.0
44

55
from abc import ABC, abstractmethod
6+
from typing import Generator
67

78
import numpy as np
89
from numpy._typing import NDArray
@@ -34,6 +35,14 @@ def nr_nodes(self):
3435
def nr_branches(self):
3536
"""Returns the number of branches in the graph"""
3637

38+
@property
39+
def all_branches(self) -> Generator[tuple[int, int], None, None]:
40+
"""Returns all branches in the graph."""
41+
return (
42+
(self.internal_to_external(source), self.internal_to_external(target))
43+
for source, target in self._all_branches()
44+
)
45+
3746
@abstractmethod
3847
def external_to_internal(self, ext_node_id: int) -> int:
3948
"""Convert external node id to internal node id (internal)
@@ -307,6 +316,9 @@ def _get_components(self, substation_nodes: list[int]) -> list[list[int]]: ...
307316
@abstractmethod
308317
def _find_fundamental_cycles(self) -> list[list[int]]: ...
309318

319+
@abstractmethod
320+
def _all_branches(self) -> Generator[tuple[int, int], None, None]: ...
321+
310322

311323
def _get_branch3_branches(branch3: Branch3Array) -> BranchArray:
312324
node_1 = branch3.node_1.item()

src/power_grid_model_ds/_core/model/graphs/models/rustworkx.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# SPDX-License-Identifier: MPL-2.0
44

55
import logging
6+
from typing import Generator
67

78
import rustworkx as rx
89
from rustworkx import NoEdgeBetweenNodes
@@ -107,6 +108,9 @@ def _find_fundamental_cycles(self) -> list[list[int]]:
107108
"""
108109
return find_fundamental_cycles_rustworkx(self._graph)
109110

111+
def _all_branches(self) -> Generator[tuple[int, int], None, None]:
112+
return ((source, target) for source, target in self._graph.edge_list())
113+
110114

111115
class _NodeVisitor(BFSVisitor):
112116
def __init__(self, nodes_to_ignore: list[int]):

tests/unit/model/graphs/test_graph_model.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,24 @@ def test_graph_has_branch(graph):
3737
assert not graph.has_branch(1, 3)
3838

3939

40+
def test_graph_all_branches(graph):
41+
graph.add_node(1)
42+
graph.add_node(2)
43+
graph.add_branch(1, 2)
44+
45+
assert [(1, 2)] == list(graph.all_branches)
46+
47+
48+
def test_graph_all_branches_parallel(graph):
49+
graph.add_node(1)
50+
graph.add_node(2)
51+
graph.add_branch(1, 2)
52+
graph.add_branch(1, 2)
53+
graph.add_branch(2, 1)
54+
55+
assert [(1, 2), (1, 2), (2, 1)] == list(graph.all_branches)
56+
57+
4058
def test_graph_delete_branch(graph):
4159
"""Test whether a branch is deleted correctly"""
4260
graph.add_node(1)

0 commit comments

Comments
 (0)