|
3 | 3 | # SPDX-License-Identifier: MPL-2.0 |
4 | 4 |
|
5 | 5 | from abc import ABC, abstractmethod |
| 6 | +from contextlib import contextmanager |
| 7 | +from typing import Generator |
6 | 8 |
|
7 | 9 | import numpy as np |
8 | 10 | from numpy._typing import NDArray |
@@ -34,6 +36,14 @@ def nr_nodes(self): |
34 | 36 | def nr_branches(self): |
35 | 37 | """Returns the number of branches in the graph""" |
36 | 38 |
|
| 39 | + @property |
| 40 | + def all_branches(self) -> Generator[tuple[int, int], None, None]: |
| 41 | + """Returns all branches in the graph.""" |
| 42 | + return ( |
| 43 | + (self.internal_to_external(source), self.internal_to_external(target)) |
| 44 | + for source, target in self._all_branches() |
| 45 | + ) |
| 46 | + |
37 | 47 | @abstractmethod |
38 | 48 | def external_to_internal(self, ext_node_id: int) -> int: |
39 | 49 | """Convert external node id to internal node id (internal) |
@@ -63,6 +73,14 @@ def has_node(self, node_id: int) -> bool: |
63 | 73 |
|
64 | 74 | return self._has_node(node_id=internal_node_id) |
65 | 75 |
|
| 76 | + def in_branches(self, node_id: int) -> Generator[tuple[int, int], None, None]: |
| 77 | + """Return all branches that have the node as an endpoint.""" |
| 78 | + int_node_id = self.external_to_internal(node_id) |
| 79 | + internal_edges = self._in_branches(int_node_id=int_node_id) |
| 80 | + return ( |
| 81 | + (self.internal_to_external(source), self.internal_to_external(target)) for source, target in internal_edges |
| 82 | + ) |
| 83 | + |
66 | 84 | def add_node(self, ext_node_id: int, raise_on_fail: bool = True) -> None: |
67 | 85 | """Add a node to the graph.""" |
68 | 86 | if self.has_node(ext_node_id): |
@@ -164,6 +182,28 @@ def delete_branch3_array(self, branch3_array: Branch3Array, raise_on_fail: bool |
164 | 182 | branches = _get_branch3_branches(branch3) |
165 | 183 | self.delete_branch_array(branches, raise_on_fail=raise_on_fail) |
166 | 184 |
|
| 185 | + @contextmanager |
| 186 | + def tmp_remove_nodes(self, nodes: list[int]) -> Generator: |
| 187 | + """Context manager that temporarily removes nodes and their branches from the graph. |
| 188 | + Example: |
| 189 | + >>> with graph.tmp_remove_nodes([1, 2, 3]): |
| 190 | + >>> assert not graph.has_node(1) |
| 191 | + >>> assert graph.has_node(1) |
| 192 | + In practice, this is useful when you want to e.g. calculate the shortest path between two nodes without |
| 193 | + considering certain nodes. |
| 194 | + """ |
| 195 | + edge_list = [] |
| 196 | + for node in nodes: |
| 197 | + edge_list += list(self.in_branches(node)) |
| 198 | + self.delete_node(node) |
| 199 | + |
| 200 | + yield |
| 201 | + |
| 202 | + for node in nodes: |
| 203 | + self.add_node(node) |
| 204 | + for source, target in edge_list: |
| 205 | + self.add_branch(source, target) |
| 206 | + |
167 | 207 | def get_shortest_path(self, ext_start_node_id: int, ext_end_node_id: int) -> tuple[list[int], int]: |
168 | 208 | """Calculate the shortest path between two nodes |
169 | 209 |
|
@@ -311,6 +351,9 @@ def _branch_is_relevant(self, branch: BranchArray) -> bool: |
311 | 351 | return branch.is_active.item() |
312 | 352 | return True |
313 | 353 |
|
| 354 | + @abstractmethod |
| 355 | + def _in_branches(self, int_node_id: int) -> Generator[tuple[int, int], None, None]: ... |
| 356 | + |
314 | 357 | @abstractmethod |
315 | 358 | def _get_connected(self, node_id: int, nodes_to_ignore: list[int], inclusive: bool = False) -> list[int]: ... |
316 | 359 |
|
@@ -351,6 +394,9 @@ def _get_components(self, substation_nodes: list[int]) -> list[list[int]]: ... |
351 | 394 | @abstractmethod |
352 | 395 | def _find_fundamental_cycles(self) -> list[list[int]]: ... |
353 | 396 |
|
| 397 | + @abstractmethod |
| 398 | + def _all_branches(self) -> Generator[tuple[int, int], None, None]: ... |
| 399 | + |
354 | 400 |
|
355 | 401 | def _get_branch3_branches(branch3: Branch3Array) -> BranchArray: |
356 | 402 | node_1 = branch3.node_1.item() |
|
0 commit comments