Skip to content

Commit a732914

Browse files
committed
Merge branch 'feat/tmp-remove-nodes' of https://github.com/PowerGridModel/power-grid-model-ds into feat/improve_get_components
2 parents 76fadb3 + c223e83 commit a732914

File tree

4 files changed

+112
-1
lines changed

4 files changed

+112
-1
lines changed

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
1.0
1+
1.1

src/power_grid_model_ds/_core/model/graphs/models/base.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
# SPDX-License-Identifier: MPL-2.0
44

55
from abc import ABC, abstractmethod
6+
from contextlib import contextmanager
7+
from typing import Generator
68

79
import numpy as np
810
from numpy._typing import NDArray
@@ -34,6 +36,14 @@ def nr_nodes(self):
3436
def nr_branches(self):
3537
"""Returns the number of branches in the graph"""
3638

39+
@property
40+
def all_branches(self) -> Generator[tuple[int, int], None, None]:
41+
"""Returns all branches in the graph."""
42+
return (
43+
(self.internal_to_external(source), self.internal_to_external(target))
44+
for source, target in self._all_branches()
45+
)
46+
3747
@abstractmethod
3848
def external_to_internal(self, ext_node_id: int) -> int:
3949
"""Convert external node id to internal node id (internal)
@@ -63,6 +73,14 @@ def has_node(self, node_id: int) -> bool:
6373

6474
return self._has_node(node_id=internal_node_id)
6575

76+
def in_branches(self, node_id: int) -> Generator[tuple[int, int], None, None]:
77+
"""Return all branches that have the node as an endpoint."""
78+
int_node_id = self.external_to_internal(node_id)
79+
internal_edges = self._in_branches(int_node_id=int_node_id)
80+
return (
81+
(self.internal_to_external(source), self.internal_to_external(target)) for source, target in internal_edges
82+
)
83+
6684
def add_node(self, ext_node_id: int, raise_on_fail: bool = True) -> None:
6785
"""Add a node to the graph."""
6886
if self.has_node(ext_node_id):
@@ -164,6 +182,28 @@ def delete_branch3_array(self, branch_array: Branch3Array, raise_on_fail: bool =
164182
branches = _get_branch3_branches(branch3)
165183
self.delete_branch_array(branches, raise_on_fail=raise_on_fail)
166184

185+
@contextmanager
186+
def tmp_remove_nodes(self, nodes: list[int]) -> Generator:
187+
"""Context manager that temporarily removes nodes and their branches from the graph.
188+
Example:
189+
>>> with graph.tmp_remove_nodes([1, 2, 3]):
190+
>>> assert not graph.has_node(1)
191+
>>> assert graph.has_node(1)
192+
In practice, this is useful when you want to e.g. calculate the shortest path between two nodes without
193+
considering certain nodes.
194+
"""
195+
edge_list = []
196+
for node in nodes:
197+
edge_list += list(self.in_branches(node))
198+
self.delete_node(node)
199+
200+
yield
201+
202+
for node in nodes:
203+
self.add_node(node)
204+
for source, target in edge_list:
205+
self.add_branch(source, target)
206+
167207
def get_shortest_path(self, ext_start_node_id: int, ext_end_node_id: int) -> tuple[list[int], int]:
168208
"""Calculate the shortest path between two nodes
169209
@@ -270,6 +310,9 @@ def _branch_is_relevant(self, branch: BranchArray) -> bool:
270310
return branch.is_active.item()
271311
return True
272312

313+
@abstractmethod
314+
def _in_branches(self, int_node_id: int) -> Generator[tuple[int, int], None, None]: ...
315+
273316
@abstractmethod
274317
def _get_connected(self, node_id: int, nodes_to_ignore: list[int], inclusive: bool = False) -> list[int]: ...
275318

@@ -307,6 +350,9 @@ def _get_components(self, substation_nodes: list[int]) -> list[list[int]]: ...
307350
@abstractmethod
308351
def _find_fundamental_cycles(self) -> list[list[int]]: ...
309352

353+
@abstractmethod
354+
def _all_branches(self) -> Generator[tuple[int, int], None, None]: ...
355+
310356

311357
def _get_branch3_branches(branch3: Branch3Array) -> BranchArray:
312358
node_1 = branch3.node_1.item()

src/power_grid_model_ds/_core/model/graphs/models/rustworkx.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# SPDX-License-Identifier: MPL-2.0
44

55
import logging
6+
from typing import Generator
67

78
import rustworkx as rx
89
from rustworkx import NoEdgeBetweenNodes
@@ -99,6 +100,9 @@ def _get_connected(self, node_id: int, nodes_to_ignore: list[int], inclusive: bo
99100

100101
return connected_nodes
101102

103+
def _in_branches(self, int_node_id: int) -> Generator[tuple[int, int], None, None]:
104+
return ((source, target) for source, target, _ in self._graph.in_edges(int_node_id))
105+
102106
def _find_fundamental_cycles(self) -> list[list[int]]:
103107
"""Find all fundamental cycles in the graph using Rustworkx.
104108
@@ -107,6 +111,9 @@ def _find_fundamental_cycles(self) -> list[list[int]]:
107111
"""
108112
return find_fundamental_cycles_rustworkx(self._graph)
109113

114+
def _all_branches(self) -> Generator[tuple[int, int], None, None]:
115+
return ((source, target) for source, target in self._graph.edge_list())
116+
110117

111118
class _NodeVisitor(BFSVisitor):
112119
def __init__(self, nodes_to_ignore: list[int]):

tests/unit/model/graphs/test_graph_model.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
"""Grid tests"""
66

7+
from collections import Counter
8+
79
import numpy as np
810
import pytest
911
from numpy.testing import assert_array_equal
@@ -37,6 +39,35 @@ def test_graph_has_branch(graph):
3739
assert not graph.has_branch(1, 3)
3840

3941

42+
def test_graph_all_branches(graph):
43+
graph.add_node(1)
44+
graph.add_node(2)
45+
graph.add_branch(1, 2)
46+
47+
assert [(1, 2)] == list(graph.all_branches)
48+
49+
50+
def test_graph_all_branches_parallel(graph):
51+
graph.add_node(1)
52+
graph.add_node(2)
53+
graph.add_branch(1, 2)
54+
graph.add_branch(1, 2)
55+
graph.add_branch(2, 1)
56+
57+
assert [(1, 2), (1, 2), (2, 1)] == list(graph.all_branches)
58+
59+
60+
def test_graph_in_branches(graph):
61+
graph.add_node(1)
62+
graph.add_node(2)
63+
graph.add_branch(1, 2)
64+
graph.add_branch(1, 2)
65+
graph.add_branch(2, 1)
66+
67+
assert [(2, 1), (2, 1), (2, 1)] == list(graph.in_branches(1))
68+
assert [(1, 2), (1, 2), (1, 2)] == list(graph.in_branches(2))
69+
70+
4071
def test_graph_delete_branch(graph):
4172
"""Test whether a branch is deleted correctly"""
4273
graph.add_node(1)
@@ -320,3 +351,30 @@ def test_get_connected_ignore_multiple_nodes(self, graph_with_2_routes):
320351
connected_nodes = graph.get_connected(node_id=1, nodes_to_ignore=[2, 4])
321352

322353
assert {5} == set(connected_nodes)
354+
355+
356+
def test_tmp_remove_nodes(graph_with_2_routes) -> None:
357+
graph = graph_with_2_routes
358+
359+
assert graph.nr_branches == 4
360+
361+
# add parallel branches to test whether they are restored correctly
362+
graph.add_branch(1, 5)
363+
graph.add_branch(5, 1)
364+
365+
assert graph.nr_nodes == 5
366+
assert graph.nr_branches == 6
367+
368+
before_sets = [frozenset(branch) for branch in graph.all_branches]
369+
counter_before = Counter(before_sets)
370+
371+
with graph.tmp_remove_nodes([1, 2]):
372+
assert graph.nr_nodes == 3
373+
assert list(graph.all_branches) == [(5, 4)]
374+
375+
assert graph.nr_nodes == 5
376+
assert graph.nr_branches == 6
377+
378+
after_sets = [frozenset(branch) for branch in graph.all_branches]
379+
counter_after = Counter(after_sets)
380+
assert counter_before == counter_after

0 commit comments

Comments
 (0)