Skip to content

Commit 7871655

Browse files
Thijssvincentkoppen
andcommitted
Apply suggestions from code review
Co-authored-by: Vincent Koppen <[email protected]> Signed-off-by: Thijs Baaijen <[email protected]>
1 parent f237fe2 commit 7871655

File tree

3 files changed

+15
-13
lines changed

3 files changed

+15
-13
lines changed

src/power_grid_model_ds/_core/model/graphs/models/base.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# SPDX-License-Identifier: MPL-2.0
44

55
from abc import ABC, abstractmethod
6+
from typing import Generator
67

78
import numpy as np
89
from numpy._typing import NDArray
@@ -35,11 +36,14 @@ def nr_branches(self):
3536
"""Returns the number of branches in the graph"""
3637

3738
@property
38-
@abstractmethod
39-
def all_branches(self) -> list[frozenset[int]]:
39+
def all_branches(self) -> Generator[tuple[int, int], None, None]:
4040
"""Returns all branches in the graph as a list of node pairs (frozensets).
4141
Warning: Depending on graph engine, performance could be slow for large graphs
4242
"""
43+
return (
44+
(self.internal_to_external(source), self.internal_to_external(target))
45+
for source, target in self._all_branches()
46+
)
4347

4448
@abstractmethod
4549
def external_to_internal(self, ext_node_id: int) -> int:
@@ -314,6 +318,9 @@ def _get_components(self, substation_nodes: list[int]) -> list[list[int]]: ...
314318
@abstractmethod
315319
def _find_fundamental_cycles(self) -> list[list[int]]: ...
316320

321+
@abstractmethod
322+
def _all_branches(self) -> Generator[tuple[int, int], None, None]: ...
323+
317324

318325
def _get_branch3_branches(branch3: Branch3Array) -> BranchArray:
319326
node_1 = branch3.node_1.item()

src/power_grid_model_ds/_core/model/graphs/models/rustworkx.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# SPDX-License-Identifier: MPL-2.0
44

55
import logging
6+
from typing import Generator
67

78
import rustworkx as rx
89
from rustworkx import NoEdgeBetweenNodes
@@ -32,15 +33,6 @@ def nr_nodes(self):
3233
def nr_branches(self):
3334
return self._graph.num_edges()
3435

35-
@property
36-
def all_branches(self) -> list[frozenset[int]]:
37-
internal_branches = ((source, target) for source, target in self._graph.edge_list())
38-
external_branches = [
39-
frozenset([self.internal_to_external(source), self.internal_to_external(target)])
40-
for source, target in internal_branches
41-
]
42-
return external_branches
43-
4436
@property
4537
def external_ids(self) -> list[int]:
4638
return list(self._external_to_internal.keys())
@@ -116,6 +108,9 @@ def _find_fundamental_cycles(self) -> list[list[int]]:
116108
"""
117109
return find_fundamental_cycles_rustworkx(self._graph)
118110

111+
def _all_branches(self) -> Generator[tuple[int, int], None, None]:
112+
return ((source, target) for source, target in self._graph.edge_list())
113+
119114

120115
class _NodeVisitor(BFSVisitor):
121116
def __init__(self, nodes_to_ignore: list[int]):

tests/unit/model/graphs/test_graph_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def test_graph_all_branches(graph):
4242
graph.add_node(2)
4343
graph.add_branch(1, 2)
4444

45-
assert [{1, 2}] == graph.all_branches
45+
assert [(1, 2)] == list(graph.all_branches)
4646

4747

4848
def test_graph_all_branches_parallel(graph):
@@ -52,7 +52,7 @@ def test_graph_all_branches_parallel(graph):
5252
graph.add_branch(1, 2)
5353
graph.add_branch(2, 1)
5454

55-
assert [{1, 2}, {1, 2}, {1, 2}] == graph.all_branches
55+
assert [(1, 2), (1, 2), (2, 1)] == list(graph.all_branches)
5656

5757

5858
def test_graph_delete_branch(graph):

0 commit comments

Comments
 (0)