|
3 | 3 | # SPDX-License-Identifier: MPL-2.0 |
4 | 4 |
|
5 | 5 | from abc import ABC, abstractmethod |
| 6 | +from contextlib import contextmanager |
| 7 | +from typing import Generator |
6 | 8 |
|
7 | 9 | import numpy as np |
8 | 10 | from numpy._typing import NDArray |
@@ -164,6 +166,31 @@ def delete_branch3_array(self, branch_array: Branch3Array, raise_on_fail: bool = |
164 | 166 | branches = _get_branch3_branches(branch3) |
165 | 167 | self.delete_branch_array(branches, raise_on_fail=raise_on_fail) |
166 | 168 |
|
| 169 | + @contextmanager |
| 170 | + def tmp_remove_nodes(self, nodes: list[int]) -> Generator: |
| 171 | + """Context manager that temporarily removes nodes and their branches from the graph. |
| 172 | + Example: |
| 173 | + >>> with graph.tmp_remove_nodes([1, 2, 3]): |
| 174 | + >>> assert not graph.has_node(1) |
| 175 | + >>> assert graph.has_node(1) |
| 176 | + In practice, this is useful when you want to e.g. calculate the shortest path between two nodes without |
| 177 | + considering certain nodes. |
| 178 | + """ |
| 179 | + edge_list = [] |
| 180 | + for node in nodes: |
| 181 | + internal_node = self.external_to_internal(node) |
| 182 | + node_edges = [ |
| 183 | + (self.internal_to_external(source), self.internal_to_external(target)) |
| 184 | + for source, target in self._in_edges(internal_node) |
| 185 | + ] |
| 186 | + edge_list += node_edges |
| 187 | + self._delete_node(internal_node) |
| 188 | + yield edge_list |
| 189 | + for node in nodes: |
| 190 | + self.add_node(node) |
| 191 | + for source, target in edge_list: |
| 192 | + self.add_branch(source, target) |
| 193 | + |
167 | 194 | def get_shortest_path(self, ext_start_node_id: int, ext_end_node_id: int) -> tuple[list[int], int]: |
168 | 195 | """Calculate the shortest path between two nodes |
169 | 196 |
|
@@ -270,6 +297,13 @@ def _branch_is_relevant(self, branch: BranchArray) -> bool: |
270 | 297 | return branch.is_active.item() |
271 | 298 | return True |
272 | 299 |
|
| 300 | + @abstractmethod |
| 301 | + def _in_edges(self, internal_node: int) -> list[tuple[int, int]]: |
| 302 | + """Return all edges a node occurs in. |
| 303 | +
|
| 304 | + Return a list of tuples with the source and target node id. These are internal node ids. |
| 305 | + """ |
| 306 | + |
273 | 307 | @abstractmethod |
274 | 308 | def _get_connected(self, node_id: int, nodes_to_ignore: list[int], inclusive: bool = False) -> list[int]: ... |
275 | 309 |
|
|
0 commit comments