Skip to content
Merged
74 changes: 47 additions & 27 deletions src/power_grid_model_ds/_core/model/graphs/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""Stores the GraphContainer class"""

import dataclasses
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Generator

Expand Down Expand Up @@ -56,47 +57,68 @@ def empty(cls, graph_model: type[BaseGraphModel] = RustworkxGraphModel) -> "Grap
complete_graph=graph_model(active_only=False),
)

def add_node_array(self, node_array: NodeArray) -> None:
"""Add a node to all graphs"""
for field in dataclasses.fields(self):
graph = getattr(self, field.name)
graph.add_node_array(node_array=node_array, raise_on_fail=False)

def add_node(self, node: NodeArray) -> None:
"""Add a node to all graphs"""
warnings.warn(
"add_node is deprecated and will be removed in a future release, use add_node_array instead",
category=DeprecationWarning,
stacklevel=2,
)
self.add_node_array(node_array=node)

def add_branch_array(self, branch_array: BranchArray) -> None:
"""Add a branch to all graphs"""
for field in self.graph_attributes:
graph = getattr(self, field.name)
graph.add_branch_array(branch_array=branch_array)

def add_branch(self, branch: BranchArray) -> None:
"""Add a branch to all graphs"""
warnings.warn(
"add_branch is deprecated and will be removed in a future release, use add_branch_array instead",
category=DeprecationWarning,
stacklevel=2,
)
self.add_branch_array(branch_array=branch)

def add_branch3_array(self, branch3_array: Branch3Array) -> None:
"""Add a branch to all graphs"""
for field in self.graph_attributes:
graph = getattr(self, field.name)
graph.add_branch_array(branch_array=branch)
setattr(self, field.name, graph)
graph.add_branch3_array(branch3_array=branch3_array)

def add_branch3(self, branch: Branch3Array) -> None:
"""Add a branch to all graphs"""
for field in self.graph_attributes:
warnings.warn(
"add_branch3 is deprecated and will be removed in a future release, use add_branch3_array instead",
category=DeprecationWarning,
stacklevel=2,
)
self.add_branch3_array(branch3_array=branch)

def delete_node(self, node: NodeArray) -> None:
"""Remove a node from all graphs"""
for field in dataclasses.fields(self):
graph = getattr(self, field.name)
graph.add_branch3_array(branch3_array=branch)
setattr(self, field.name, graph)
graph.delete_node_array(node_array=node)

def delete_branch(self, branch: BranchArray) -> None:
"""Remove a branch from all graphs"""
for field in self.graph_attributes:
graph = getattr(self, field.name)
graph.delete_branch_array(branch_array=branch)
setattr(self, field.name, graph)

def delete_branch3(self, branch: Branch3Array) -> None:
"""Remove a branch from all graphs"""
for field in self.graph_attributes:
graph = getattr(self, field.name)
graph.delete_branch3_array(branch3_array=branch)
setattr(self, field.name, graph)

def add_node(self, node: NodeArray) -> None:
"""Add a node to all graphs"""
for field in dataclasses.fields(self):
graph = getattr(self, field.name)
graph.add_node_array(node_array=node, raise_on_fail=False)
setattr(self, field.name, graph)

def delete_node(self, node: NodeArray) -> None:
"""Remove a node from all graphs"""
for field in dataclasses.fields(self):
graph = getattr(self, field.name)
graph.delete_node_array(node_array=node)
setattr(self, field.name, graph)

def make_active(self, branch: BranchArray) -> None:
"""Add branch to all active_only graphs"""
Expand All @@ -107,7 +129,6 @@ def make_active(self, branch: BranchArray) -> None:
graph = getattr(self, field.name)
if graph.active_only:
graph.add_branch(from_ext_node_id=from_node, to_ext_node_id=to_node)
setattr(self, field.name, graph)

def make_inactive(self, branch: BranchArray) -> None:
"""Remove a branch from all active_only graphs"""
Expand All @@ -118,7 +139,6 @@ def make_inactive(self, branch: BranchArray) -> None:
graph = getattr(self, field.name)
if graph.active_only:
graph.delete_branch(from_ext_node_id=from_node, to_ext_node_id=to_node)
setattr(self, field.name, graph)

@classmethod
def from_arrays(cls, arrays: "Grid") -> "GraphContainer":
Expand All @@ -142,9 +162,9 @@ def _validate_branches(arrays: "Grid") -> None:
raise RecordDoesNotExist(f"Found invalid .to_node values in {array.__class__.__name__}")

def _append(self, array: FancyArray) -> None:
if isinstance(array, NodeArray):
self.add_node_array(array)
if isinstance(array, BranchArray):
self.add_branch(array)
self.add_branch_array(array)
if isinstance(array, Branch3Array):
self.add_branch3(array)
if isinstance(array, NodeArray):
self.add_node(array)
self.add_branch3_array(array)
26 changes: 19 additions & 7 deletions src/power_grid_model_ds/_core/model/graphs/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,9 @@ def delete_node(self, ext_node_id: int, raise_on_fail: bool = True) -> None:

def add_node_array(self, node_array: NodeArray, raise_on_fail: bool = True) -> None:
"""Add all nodes in the node array to the graph."""
for node in node_array:
self.add_node(ext_node_id=node.id.item(), raise_on_fail=raise_on_fail)
if raise_on_fail and any(self.has_node(x) for x in node_array["id"]):
raise GraphError("At least one node id already exists in the Graph.")
self._add_nodes(node_array["id"].tolist())

def delete_node_array(self, node_array: NodeArray, raise_on_fail: bool = True) -> None:
"""Delete all nodes in node_array from the graph"""
Expand Down Expand Up @@ -162,9 +163,14 @@ def delete_branch(self, from_ext_node_id: int, to_ext_node_id: int, raise_on_fai

def add_branch_array(self, branch_array: BranchArray) -> None:
"""Add all branches in the branch array to the graph."""
for branch in branch_array:
if self._branch_is_relevant(branch):
self.add_branch(branch.from_node.item(), branch.to_node.item())
if self.active_only:
branch_array = branch_array[branch_array.is_active]
if not branch_array.size:
return

from_node_ids = self._externals_to_internals(branch_array["from_node"].tolist())
to_node_ids = self._externals_to_internals(branch_array["to_node"].tolist())
self._add_branches(from_node_ids, to_node_ids)

def add_branch3_array(self, branch3_array: Branch3Array) -> None:
"""Add all branch3s in the branch3 array to the graph."""
Expand Down Expand Up @@ -333,7 +339,7 @@ def from_arrays(cls, arrays: "Grid", active_only=False) -> "BaseGraphModel":
"""Build from arrays"""
new_graph = cls(active_only=active_only)

new_graph.add_node_array(node_array=arrays.node)
new_graph.add_node_array(node_array=arrays.node, raise_on_fail=False)
new_graph.add_branch_array(arrays.branches)
new_graph.add_branch3_array(arrays.three_winding_transformer)

Expand Down Expand Up @@ -371,11 +377,17 @@ def _has_node(self, node_id) -> bool: ...
@abstractmethod
def _add_node(self, ext_node_id: int) -> None: ...

@abstractmethod
def _add_nodes(self, ext_node_ids: list[int]) -> None: ...

@abstractmethod
def _delete_node(self, node_id: int): ...

@abstractmethod
def _add_branch(self, from_node_id, to_node_id) -> None: ...
def _add_branch(self, from_node_id: int, to_node_id: int) -> None: ...

@abstractmethod
def _add_branches(self, from_node_ids: list[int], to_node_ids: list[int]) -> None: ...

@abstractmethod
def _delete_branch(self, from_node_id, to_node_id) -> None:
Expand Down
10 changes: 10 additions & 0 deletions src/power_grid_model_ds/_core/model/graphs/models/rustworkx.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ def _add_node(self, ext_node_id: int):
self._external_to_internal[ext_node_id] = graph_node_id
self._internal_to_external[graph_node_id] = ext_node_id

def _add_nodes(self, ext_node_ids: list[int]) -> None:
graph_node_ids = self._graph.add_nodes_from(ext_node_ids)
for ext_node_id, graph_node_id in zip(ext_node_ids, graph_node_ids):
self._external_to_internal[ext_node_id] = graph_node_id
self._internal_to_external[graph_node_id] = ext_node_id

def _delete_node(self, node_id: int):
self._graph.remove_node(node_id)
external_node_id = self._internal_to_external.pop(node_id)
Expand All @@ -66,6 +72,10 @@ def _has_node(self, node_id: int) -> bool:
def _add_branch(self, from_node_id: int, to_node_id: int):
self._graph.add_edge(from_node_id, to_node_id, None)

def _add_branches(self, from_node_ids: list[int], to_node_ids: list[int]):
edge_list = [(from_node_id, to_node_id, None) for from_node_id, to_node_id in zip(from_node_ids, to_node_ids)]
self._graph.add_edges_from(edge_list)

def _delete_branch(self, from_node_id: int, to_node_id: int) -> None:
try:
self._graph.remove_edge(from_node_id, to_node_id)
Expand Down
10 changes: 5 additions & 5 deletions src/power_grid_model_ds/_core/model/grids/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,9 @@ def append(self, array: FancyArray, check_max_id: bool = True):
check_max_id (bool, optional): Whether to check if the array id is the maximum id. Defaults to True.
"""
self._append(array, check_max_id=check_max_id) # noqa
for row in array:
# pylint: disable=protected-access
self.graphs._append(row)

# pylint: disable=protected-access
self.graphs._append(array)

def add_branch(self, branch: BranchArray) -> None:
"""Add a branch to the grid
Expand All @@ -211,7 +211,7 @@ def add_branch(self, branch: BranchArray) -> None:
branch (BranchArray): The branch to add
"""
self._append(array=branch)
self.graphs.add_branch(branch=branch)
self.graphs.add_branch_array(branch_array=branch)

logging.debug(f"added branch {branch.id} from {branch.from_node} to {branch.to_node}")

Expand Down Expand Up @@ -243,7 +243,7 @@ def add_node(self, node: NodeArray) -> None:
node (NodeArray): The node to add
"""
self._append(array=node)
self.graphs.add_node(node=node)
self.graphs.add_node_array(node_array=node)
logging.debug(f"added rail {node.id}")

def delete_node(self, node: NodeArray) -> None:
Expand Down
36 changes: 34 additions & 2 deletions tests/performance/grid_performance_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,37 @@
# pylint: disable=missing-function-docstring


def test_get_downstream_nodes_performance():
def perf_test_add_nodes():
setup_code = {
"grid": "from power_grid_model_ds import Grid;"
+ "from power_grid_model_ds._core.model.arrays import NodeArray;"
+ "grid = Grid.empty();"
+ "nodes = NodeArray.zeros({size});"
}

code_to_test = ["grid.append(nodes);"]

do_performance_test(code_to_test, [10, 200, 1000], 100, setup_code)


def perf_test_add_lines():
setup_code = {
"grid": "from power_grid_model_ds import Grid;"
+ "from power_grid_model_ds._core.model.arrays import NodeArray, LineArray;"
+ "grid = Grid.empty();"
+ "nodes = NodeArray.zeros({size});"
+ "grid.append(nodes);"
+ "lines = LineArray.zeros({size});"
+ "lines.from_node = nodes.id;"
+ "lines.to_node = nodes.id;"
}

code_to_test = ["grid.append(lines);"]

do_performance_test(code_to_test, [10, 200, 1000], 100, setup_code)


def perf_test_get_downstream_nodes_performance():
setup_code = {
"grid": "import numpy as np;"
+ "from power_grid_model_ds.enums import NodeType;"
Expand All @@ -27,4 +57,6 @@ def test_get_downstream_nodes_performance():


if __name__ == "__main__":
test_get_downstream_nodes_performance()
perf_test_get_downstream_nodes_performance()
perf_test_add_nodes()
perf_test_add_lines()
19 changes: 5 additions & 14 deletions tests/unit/model/graphs/test_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from power_grid_model_ds._core.model.arrays import NodeArray, ThreeWindingTransformerArray
from power_grid_model_ds._core.model.arrays.base.errors import RecordDoesNotExist
from power_grid_model_ds._core.model.graphs.container import GraphContainer
from power_grid_model_ds._core.model.graphs.errors import GraphError
from power_grid_model_ds._core.model.grids.base import Grid

# pylint: disable=missing-function-docstring
Expand Down Expand Up @@ -122,23 +121,15 @@ def test_from_arrays_partially_active_three_winding(basic_grid: Grid):
assert basic_grid.graphs.complete_graph.nr_nodes == graphs.complete_graph.nr_nodes
assert basic_grid.graphs.complete_graph.nr_branches == 6 + 3

# Implicitly test that the correct branches are added
# Current implementation does not have a has_branch method.
basic_grid.graphs.complete_graph.delete_branch(1000, 1001)
basic_grid.graphs.complete_graph.delete_branch(1000, 1002)
basic_grid.graphs.complete_graph.delete_branch(1001, 1002)
basic_grid.graphs.active_graph.has_branch(1000, 1002)
basic_grid.graphs.active_graph.has_branch(1001, 1002)

assert basic_grid.graphs.active_graph.nr_nodes == graphs.active_graph.nr_nodes
assert basic_grid.graphs.active_graph.nr_branches == 5 + 1

# Implicitly test that the correct branches are added
# Current implementation does not have a has_branch method.
basic_grid.graphs.active_graph.delete_branch(1000, 1001)
with pytest.raises(GraphError):
basic_grid.graphs.active_graph.delete_branch(1000, 1002)

with pytest.raises(GraphError):
basic_grid.graphs.active_graph.delete_branch(1001, 1002)
basic_grid.graphs.active_graph.has_branch(1000, 1001)
assert not basic_grid.graphs.active_graph.has_branch(1000, 1002)
assert not basic_grid.graphs.active_graph.has_branch(1001, 1002)


def test_from_arrays_invalid_arrays(basic_grid: Grid):
Expand Down
6 changes: 6 additions & 0 deletions tests/unit/model/graphs/test_graph_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pytest
from numpy.testing import assert_array_equal

from power_grid_model_ds._core.model.graphs.errors import GraphError
from power_grid_model_ds._core.model.graphs.models.base import BaseGraphModel
from power_grid_model_ds._core.model.grids.base import Grid
from power_grid_model_ds.errors import MissingBranchError, MissingNodeError, NoPathBetweenNodes
Expand All @@ -30,6 +31,11 @@ def test_graph_add_node_and_branch(self, graph: BaseGraphModel):
assert 2 == graph.nr_nodes
assert 1 == graph.nr_branches

def test_add_node_already_there(self, graph: BaseGraphModel):
graph.add_node(1)
with pytest.raises(GraphError):
graph.add_node(1)

def test_add_invalid_branch(self, graph: BaseGraphModel):
graph.add_node(1)
graph.add_node(2)
Expand Down
16 changes: 9 additions & 7 deletions tests/unit/model/test_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,15 +156,17 @@ def test_id_counter():


def test_branches(grid: Grid):
node = NodeArray.zeros(10, empty_id=False)
grid.append(node)
grid.append(LineArray.zeros(10))
grid.append(TransformerArray.zeros(10))
grid.append(LinkArray.zeros(10))
branches = grid.branches
nodes = NodeArray.zeros(10)
grid.append(nodes)

for branch_class in (LineArray, TransformerArray, LinkArray):
branches = branch_class.zeros(10)
branches.from_node = nodes.id
branches.to_node = list(reversed(nodes.id.tolist()))
grid.append(branches)

expected_ids = np.concatenate((grid.line.id, grid.transformer.id, grid.link.id))
assert set(expected_ids) == set(branches.id)
assert set(expected_ids) == set(grid.branches.id)


def test_delete_node_without_additional_properties(basic_grid: Grid):
Expand Down