Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions src/power_grid_model_ds/_core/model/graphs/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,49 @@ def get_connected(
nodes_to_ignore=self._externals_to_internals(nodes_to_ignore),
inclusive=inclusive,
)

return self._internals_to_externals(nodes)

def find_first_connected(self, node_id: int, candidate_node_ids: list[int]) -> int:
"""Find the first connected node to the node_id from the candidate_node_ids

Note:
If multiple candidate nodes are connected to the node, the first one found is returned.
There is no guarantee that the same candidate node will be returned each time.

Raises:
MissingNodeError: if no connected node is found
ValueError: if the node_id is in candidate_node_ids
"""
internal_node_id = self.external_to_internal(node_id)
internal_candidates = self._externals_to_internals(candidate_node_ids)
if internal_node_id in internal_candidates:
raise ValueError("node_id cannot be in candidate_node_ids")
return self.internal_to_external(self._find_first_connected(internal_node_id, internal_candidates))

def get_downstream_nodes(self, node_id: int, start_node_ids: list[int], inclusive: bool = False) -> list[int]:
"""Find all nodes downstream of the node_id with respect to the start_node_ids

Example:
given this graph: [1] - [2] - [3] - [4]
>>> graph.get_downstream_nodes(2, [1]) == [3, 4]
>>> graph.get_downstream_nodes(2, [1], inclusive=True) == [2, 3, 4]

args:
node_id: node id to start the search from
start_node_ids: list of node ids considered 'above' the node_id
inclusive: whether to include the given node id in the result
returns:
list of node ids sorted by distance, downstream of to the node id
"""
connected_node = self.find_first_connected(node_id, start_node_ids)
path, _ = self.get_shortest_path(node_id, connected_node)
_, upstream_node, *_ = (
path # path is at least 2 elements long or find_first_connected would have raised an error
)

return self.get_connected(node_id, [upstream_node], inclusive)

def find_fundamental_cycles(self) -> list[list[int]]:
"""Find all fundamental cycles in the graph.
Returns:
Expand Down Expand Up @@ -273,6 +314,9 @@ def _branch_is_relevant(self, branch: BranchArray) -> bool:
@abstractmethod
def _get_connected(self, node_id: int, nodes_to_ignore: list[int], inclusive: bool = False) -> list[int]: ...

@abstractmethod
def _find_first_connected(self, node_id: int, candidate_node_ids: list[int]) -> int: ...

@abstractmethod
def _has_branch(self, from_node_id, to_node_id) -> bool: ...

Expand Down
22 changes: 21 additions & 1 deletion src/power_grid_model_ds/_core/model/graphs/models/rustworkx.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import rustworkx as rx
from rustworkx import NoEdgeBetweenNodes
from rustworkx.visit import BFSVisitor, PruneSearch
from rustworkx.visit import BFSVisitor, PruneSearch, StopSearch

from power_grid_model_ds._core.model.graphs.errors import MissingBranchError, MissingNodeError, NoPathBetweenNodes
from power_grid_model_ds._core.model.graphs.models._rustworkx_search import find_fundamental_cycles_rustworkx
Expand Down Expand Up @@ -99,6 +99,13 @@ def _get_connected(self, node_id: int, nodes_to_ignore: list[int], inclusive: bo

return connected_nodes

def _find_first_connected(self, node_id: int, candidate_node_ids: list[int]) -> int:
visitor = _NodeFinder(candidate_nodes=candidate_node_ids)
rx.bfs_search(self._graph, [node_id], visitor)
if visitor.found_node is None:
raise MissingNodeError(f"node {node_id} is not connected to any of the candidate nodes")
return visitor.found_node

def _find_fundamental_cycles(self) -> list[list[int]]:
"""Find all fundamental cycles in the graph using Rustworkx.

Expand All @@ -117,3 +124,16 @@ def discover_vertex(self, v):
if v in self.nodes_to_ignore:
raise PruneSearch
self.nodes.append(v)


class _NodeFinder(BFSVisitor):
"""Visitor that stops the search when a candidate node is found"""

def __init__(self, candidate_nodes: list[int]):
self.candidate_nodes = candidate_nodes
self.found_node: int | None = None

def discover_vertex(self, v):
if v in self.candidate_nodes:
self.found_node = v
raise StopSearch
12 changes: 6 additions & 6 deletions src/power_grid_model_ds/_core/model/grids/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ def get_nearest_substation_node(self, node_id: int):

def get_downstream_nodes(self, node_id: int, inclusive: bool = False):
"""Get the downstream nodes from a node.
Assuming each node has a single feeding substation and the grid is radial

Example:
given this graph: [1] - [2] - [3] - [4], with 1 being a substation node
Expand All @@ -349,15 +350,14 @@ def get_downstream_nodes(self, node_id: int, inclusive: bool = False):
Returns:
list[int]: The downstream nodes.
"""
substation_node_id = self.get_nearest_substation_node(node_id).id.item()
substation_nodes = self.node.filter(node_type=NodeType.SUBSTATION_NODE.value)

if node_id == substation_node_id:
if node_id in substation_nodes.id:
raise NotImplementedError("get_downstream_nodes is not implemented for substation nodes!")

path_to_substation, _ = self.graphs.active_graph.get_shortest_path(node_id, substation_node_id)
upstream_node = path_to_substation[1]

return self.graphs.active_graph.get_connected(node_id, nodes_to_ignore=[upstream_node], inclusive=inclusive)
return self.graphs.active_graph.get_downstream_nodes(
node_id=node_id, start_node_ids=list(substation_nodes.id), inclusive=inclusive
)

def cache(self, cache_dir: Path, cache_name: str, compress: bool = True):
"""Cache Grid to a folder
Expand Down
61 changes: 23 additions & 38 deletions tests/unit/model/grids/test_grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,57 +18,42 @@ def test_grid_get_nearest_substation_node(basic_grid):


def test_grid_get_nearest_substation_node_no_substation(basic_grid):
"""Test that an error is raised when there is no substation connected to the node"""
substation_node = basic_grid.node.get(node_type=NodeType.SUBSTATION_NODE.value)
basic_grid.delete_node(substation_node)

with pytest.raises(RecordDoesNotExist):
basic_grid.get_nearest_substation_node(node_id=103)


def test_get_downstream_nodes(basic_grid: Grid):
"""Test that get_downstream_nodes returns the expected nodes."""
# Move the open line to be able to test sorting of nodes by distance correctly
basic_grid.make_active(basic_grid.line.get(203))
basic_grid.make_inactive(basic_grid.link.get(601))
downstream_nodes = basic_grid.get_downstream_nodes(node_id=102)
assert downstream_nodes[-1] == 104 # Furthest away
assert {103, 104, 106} == set(downstream_nodes)
class TestGetDownstreamNodes:
def test_get_downstream_nodes(self):
grid = Grid.from_txt(["S1 11", "S1 2", "2 3", "3 5", "5 6", "2 4", "4 99", "99 100"])
downstream_nodes = grid.get_downstream_nodes(node_id=3)
assert [5, 6] == downstream_nodes

downstream_nodes = basic_grid.get_downstream_nodes(node_id=102, inclusive=True)
assert downstream_nodes[0] == 102
assert downstream_nodes[-1] == 104
assert {102, 103, 104, 106} == set(downstream_nodes)
def test_get_downstream_nodes_from_substation_node(self):
grid = Grid.from_txt(["S1 11", "S1 2", "2 3", "3 5", "5 6", "2 4", "4 99", "99 100"])
with pytest.raises(NotImplementedError):
grid.get_downstream_nodes(node_id=1)


def test_get_downstream_nodes_from_substation_node(basic_grid):
"""Test that get_downstream_nodes raises the expected error when
the input node is a substation_node."""
substation_node = basic_grid.node.get(node_type=NodeType.SUBSTATION_NODE.value).record
class TestGetBranchesInPath:
def test_get_branches_in_path(self, basic_grid):
branches = basic_grid.get_branches_in_path([106, 102, 101])
np.testing.assert_array_equal(branches.id, [301, 201])

with pytest.raises(NotImplementedError):
basic_grid.get_downstream_nodes(node_id=substation_node.id)
def test_get_branches_in_path_inactive(self, basic_grid):
branches = basic_grid.get_branches_in_path([101, 102, 103, 104, 105])
# branch 203 is the normally open point should not be in the result
np.testing.assert_array_equal(branches.id, [201, 202, 204, 601])

def test_get_branches_in_path_one_node(self, basic_grid):
branches = basic_grid.get_branches_in_path([106])
assert 0 == branches.size

def test_get_branches_in_path(basic_grid):
branches = basic_grid.get_branches_in_path([106, 102, 101])
np.testing.assert_array_equal(branches.id, [301, 201])


def test_get_branches_in_path_inactive(basic_grid):
branches = basic_grid.get_branches_in_path([101, 102, 103, 104, 105])
# branch 203 is the normally open point should not be in the result
np.testing.assert_array_equal(branches.id, [201, 202, 204, 601])


def test_get_branches_in_path_one_node(basic_grid):
branches = basic_grid.get_branches_in_path([106])
assert 0 == branches.size


def test_get_branches_in_path_empty_path(basic_grid):
branches = basic_grid.get_branches_in_path([])
assert 0 == branches.size
def test_get_branches_in_path_empty_path(self, basic_grid):
branches = basic_grid.get_branches_in_path([])
assert 0 == branches.size


def test_component_three_winding_transformer(grid_with_3wt):
Expand Down