Skip to content

Commit cbcad42

Browse files
feat: improve downstream nodes performance with local search
Signed-off-by: jaapschoutenalliander <[email protected]>
1 parent 13c8664 commit cbcad42

File tree

10 files changed

+156
-127
lines changed

10 files changed

+156
-127
lines changed

src/power_grid_model_ds/_core/model/graphs/models/base.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,26 @@ def get_connected(
235235
nodes_to_ignore=self._externals_to_internals(nodes_to_ignore),
236236
inclusive=inclusive,
237237
)
238+
238239
return self._internals_to_externals(nodes)
239240

241+
def get_downstream_nodes(self, node_id: int, stop_node_ids: list[int], inclusive: bool = False) -> list[int]:
242+
"""Find all nodes connected to the node_id
243+
args:
244+
node_id: node id to start the search from
245+
stop_node_ids: list of node ids to stop the search at
246+
inclusive: whether to include the given node id in the result
247+
returns:
248+
list of node ids sorted by distance, downstream of to the node id
249+
"""
250+
downstream_nodes = self._get_downstream_nodes(
251+
node_id=self.external_to_internal(node_id),
252+
stop_node_ids=self._externals_to_internals(stop_node_ids),
253+
inclusive=inclusive,
254+
)
255+
256+
return self._internals_to_externals(downstream_nodes)
257+
240258
def find_fundamental_cycles(self) -> list[list[int]]:
241259
"""Find all fundamental cycles in the graph.
242260
Returns:
@@ -273,6 +291,9 @@ def _branch_is_relevant(self, branch: BranchArray) -> bool:
273291
@abstractmethod
274292
def _get_connected(self, node_id: int, nodes_to_ignore: list[int], inclusive: bool = False) -> list[int]: ...
275293

294+
@abstractmethod
295+
def _get_downstream_nodes(self, node_id: int, stop_node_ids: list[int], inclusive: bool = False) -> list[int]: ...
296+
276297
@abstractmethod
277298
def _has_branch(self, from_node_id, to_node_id) -> bool: ...
278299

src/power_grid_model_ds/_core/model/graphs/models/rustworkx.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,15 @@ def _get_connected(self, node_id: int, nodes_to_ignore: list[int], inclusive: bo
9999

100100
return connected_nodes
101101

102+
def _get_downstream_nodes(self, node_id: int, stop_node_ids: list[int], inclusive: bool = False) -> list[int]:
103+
visitor = _NodeVisitor(stop_node_ids)
104+
rx.bfs_search(self._graph, [node_id], visitor)
105+
connected_nodes = visitor.nodes
106+
path_to_substation, _ = self._get_shortest_path(node_id, visitor.discovered_nodes_to_ignore[0])
107+
if inclusive:
108+
_ = path_to_substation.pop(0)
109+
return [node for node in connected_nodes if node not in path_to_substation]
110+
102111
def _find_fundamental_cycles(self) -> list[list[int]]:
103112
"""Find all fundamental cycles in the graph using Rustworkx.
104113
@@ -112,8 +121,10 @@ class _NodeVisitor(BFSVisitor):
112121
def __init__(self, nodes_to_ignore: list[int]):
113122
self.nodes_to_ignore = nodes_to_ignore
114123
self.nodes: list[int] = []
124+
self.discovered_nodes_to_ignore: list[int] = []
115125

116126
def discover_vertex(self, v):
117127
if v in self.nodes_to_ignore:
128+
self.discovered_nodes_to_ignore.append(v)
118129
raise PruneSearch
119130
self.nodes.append(v)

src/power_grid_model_ds/_core/model/grids/base.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,7 @@ def get_nearest_substation_node(self, node_id: int):
331331

332332
def get_downstream_nodes(self, node_id: int, inclusive: bool = False):
333333
"""Get the downstream nodes from a node.
334+
Assuming each node has a single feeding substation and the grid is radial
334335
335336
Example:
336337
given this graph: [1] - [2] - [3] - [4], with 1 being a substation node
@@ -349,15 +350,14 @@ def get_downstream_nodes(self, node_id: int, inclusive: bool = False):
349350
Returns:
350351
list[int]: The downstream nodes.
351352
"""
352-
substation_node_id = self.get_nearest_substation_node(node_id).id.item()
353+
substation_nodes = self.node.filter(node_type=NodeType.SUBSTATION_NODE.value)
353354

354-
if node_id == substation_node_id:
355+
if node_id in substation_nodes.id:
355356
raise NotImplementedError("get_downstream_nodes is not implemented for substation nodes!")
356357

357-
path_to_substation, _ = self.graphs.active_graph.get_shortest_path(node_id, substation_node_id)
358-
upstream_node = path_to_substation[1]
359-
360-
return self.graphs.active_graph.get_connected(node_id, nodes_to_ignore=[upstream_node], inclusive=inclusive)
358+
return self.graphs.active_graph.get_downstream_nodes(
359+
node_id=node_id, stop_node_ids=list(substation_nodes.id), inclusive=inclusive
360+
)
361361

362362
def cache(self, cache_dir: Path, cache_name: str, compress: bool = True):
363363
"""Cache Grid to a folder

tests/performance/_constants.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,18 @@
66
"dtype = [('id', '<i8'), ('test_int', '<i8'), ('test_float', '<f8'), ('test_str', '<U50'), ('test_bool', '?')]; "
77
)
88

9-
SETUP_CODES = {
10-
"structured": "import numpy as np;" + NUMPY_DTYPE + "input_array = np.zeros({array_size}, dtype=dtype)",
11-
"rec": "import numpy as np;" + NUMPY_DTYPE + "input_array = np.recarray(({array_size},),dtype=dtype)",
12-
"fancy": "from tests.conftest import FancyTestArray; input_array=FancyTestArray.zeros({array_size});"
13-
+ "import numpy as np;input_array.id = np.arange({array_size})",
9+
ARRAY_SETUP_CODES = {
10+
"structured": "import numpy as np;" + NUMPY_DTYPE + "input_array = np.zeros({size}, dtype=dtype)",
11+
"rec": "import numpy as np;" + NUMPY_DTYPE + "input_array = np.recarray(({size},),dtype=dtype)",
12+
"fancy": "from tests.conftest import FancyTestArray; input_array=FancyTestArray.zeros({size});"
13+
+ "import numpy as np;input_array.id = np.arange({size})",
1414
}
1515

1616
GRAPH_SETUP_CODES = {
17-
"rustworkx": "from power_grid_model_ds.model.grids.base import Grid;"
18-
+ "from power_grid_model_ds.data_source.generator.grid_generators import RadialGridGenerator;"
19-
+ "from power_grid_model_ds.model.graphs.models import RustworkxGraphModel;"
20-
+ "grid=RadialGridGenerator(nr_nodes={graph_size}, grid_class=Grid, graph_model=RustworkxGraphModel).run()",
17+
"rustworkx": "from power_grid_model_ds import Grid;"
18+
+ "from power_grid_model_ds.generators import RadialGridGenerator;"
19+
+ "from power_grid_model_ds.graph_models import RustworkxGraphModel;"
20+
+ "grid=RadialGridGenerator(nr_nodes={size}, grid_class=Grid, graph_model=RustworkxGraphModel).run()",
2121
}
2222

2323
SINGLE_REPEATS = 1000

tests/performance/_helpers.py

Lines changed: 29 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -4,100 +4,57 @@
44

55
import inspect
66
import timeit
7-
from typing import Generator
7+
from itertools import product
8+
from typing import Generator, Union
89

9-
from tests.performance._constants import GRAPH_SETUP_CODES, SETUP_CODES
10-
11-
12-
def do_performance_test(code_to_test: str | dict[str, str], array_sizes: list[int], repeats: int):
13-
"""Run the performance test for the given code."""
1410

11+
def do_performance_test(
12+
code_to_test: Union[str, dict[str, str], list[str]],
13+
size_list: list[int],
14+
repeats: int,
15+
setup_codes: dict[str, str],
16+
):
17+
"""Generalized performance test runner."""
1518
print(f"{'-' * 20} {inspect.stack()[1][3]} {'-' * 20}")
1619

17-
for array_size in array_sizes:
20+
for size in size_list:
21+
formatted_setup_codes = {key: code.format(size=size) for key, code in setup_codes.items()}
1822
if isinstance(code_to_test, dict):
19-
code_to_test_list = [code_to_test[variant].format(array_size=array_size) for variant in SETUP_CODES]
20-
else:
21-
code_to_test_list = [code_to_test.format(array_size=array_size)] * len(SETUP_CODES)
22-
print(f"\n\tArray size: {array_size}\n")
23-
setup_codes = [setup_code.format(array_size=array_size) for setup_code in SETUP_CODES.values()]
24-
timings = _get_timings(setup_codes, code_to_test_list, repeats)
25-
26-
if code_to_test == "pass":
27-
_print_timings(timings, list(SETUP_CODES.keys()), setup_codes)
23+
code_to_test_list = [code_to_test[variant].format(size=size) for variant in setup_codes]
24+
test_generator = zip(formatted_setup_codes.items(), code_to_test_list)
25+
elif isinstance(code_to_test, list):
26+
code_to_test_list = [code.format(size=size) for code in code_to_test]
27+
test_generator = product(formatted_setup_codes.items(), code_to_test_list)
2828
else:
29-
_print_timings(timings, list(SETUP_CODES.keys()), code_to_test_list)
30-
print()
29+
test_generator = product(formatted_setup_codes.items(), [code_to_test.format(size=size)])
3130

31+
print(f"\n\tsize: {size}\n")
3232

33-
def do_graph_test(code_to_test: str | dict[str, str], graph_sizes: list[int], repeats: int):
34-
"""Run the performance test for the given code."""
33+
timings = _get_timings(test_generator, repeats=repeats)
34+
_print_timings(timings)
3535

36-
print(f"{'-' * 20} {inspect.stack()[1][3]} {'-' * 20}")
37-
38-
for graph_size in graph_sizes:
39-
if isinstance(code_to_test, dict):
40-
code_to_test_list = [code_to_test[variant] for variant in GRAPH_SETUP_CODES]
41-
else:
42-
code_to_test_list = [code_to_test] * len(GRAPH_SETUP_CODES)
43-
print(f"\n\tGraph size: {graph_size}\n")
44-
setup_codes = [setup_code.format(graph_size=graph_size) for setup_code in GRAPH_SETUP_CODES.values()]
45-
timings = _get_timings(setup_codes, code_to_test_list, repeats)
46-
47-
if code_to_test == "pass":
48-
_print_graph_timings(timings, list(GRAPH_SETUP_CODES.keys()), setup_codes)
49-
else:
50-
_print_graph_timings(timings, list(GRAPH_SETUP_CODES.keys()), code_to_test_list)
5136
print()
5237

5338

54-
def _print_test_code(code: str | dict[str, str], repeats: int):
55-
print(f"{'-' * 40}")
56-
if isinstance(code, dict):
57-
for variant, code_variant in code.items():
58-
print(f"{variant}")
59-
print(f"\t{code_variant} (x {repeats})")
60-
return
61-
print(f"{code} (x {repeats})")
62-
63-
64-
def _print_graph_timings(timings: Generator, graph_types: list[str], code_list: list[str]):
65-
for graph_type, timing, code in zip(graph_types, timings, code_list):
66-
if ";" in code:
67-
code = code.split(";")[-1]
68-
69-
code = code.replace("\n", " ").replace("\t", " ")
70-
code = f"{graph_type}: " + code
71-
72-
if isinstance(timing, Exception):
73-
print(f"\t\t{code.ljust(100)} | Not supported")
74-
continue
75-
print(f"\t\t{code.ljust(100)} | {sum(timing):.2f}s")
76-
77-
78-
def _print_timings(timings: Generator, array_types: list[str], code_list: list[str]):
79-
for array, timing, code in zip(array_types, timings, code_list):
80-
if ";" in code:
81-
code = code.split(";")[-1]
82-
83-
code = code.replace("\n", " ").replace("\t", " ")
84-
array_name = f"{array}_array"
85-
code = code.replace("input_array", array_name)
39+
def _print_timings(timings: Generator):
40+
for key, code, timing in timings:
41+
code = code.split(";")[-1].replace("\n", " ").replace("\t", " ")
42+
code = f"{key}: {code}"
8643

8744
if isinstance(timing, Exception):
8845
print(f"\t\t{code.ljust(100)} | Not supported")
8946
continue
9047
print(f"\t\t{code.ljust(100)} | {sum(timing):.2f}s")
9148

9249

93-
def _get_timings(setup_codes: list[str], test_codes: list[str], repeats: int):
50+
def _get_timings(test_generator, repeats: int):
9451
"""Return a generator with the timings for each array type."""
95-
for setup_code, test_code in zip(setup_codes, test_codes):
52+
for (key, setup_code), test_code in test_generator:
9653
if test_code == "pass":
97-
yield timeit.repeat(setup_code, number=repeats)
54+
yield key, "intialise", timeit.repeat(setup_code, number=repeats)
9855
else:
9956
try:
100-
yield timeit.repeat(test_code, setup_code, number=repeats)
57+
yield key, test_code, timeit.repeat(test_code, setup_code, number=repeats)
10158
# pylint: disable=broad-exception-caught
10259
except Exception as error: # noqa
103-
yield error
60+
yield key, test_code, error

tests/performance/array_performance_tests.py

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,24 @@
1111

1212
import logging
1313

14-
from tests.performance._constants import ARRAY_SIZES_LARGE, ARRAY_SIZES_SMALL, LOOP_REPEATS, SINGLE_REPEATS
14+
from tests.performance._constants import (
15+
ARRAY_SETUP_CODES,
16+
ARRAY_SIZES_LARGE,
17+
ARRAY_SIZES_SMALL,
18+
LOOP_REPEATS,
19+
SINGLE_REPEATS,
20+
)
1521
from tests.performance._helpers import do_performance_test
1622

1723
logging.basicConfig(level=logging.INFO)
1824

1925

2026
def perftest_initialize():
21-
do_performance_test("pass", ARRAY_SIZES_LARGE, SINGLE_REPEATS)
27+
do_performance_test("pass", ARRAY_SIZES_LARGE, SINGLE_REPEATS, ARRAY_SETUP_CODES)
2228

2329

2430
def perftest_slice():
25-
do_performance_test("input_array[0:10]", ARRAY_SIZES_LARGE, SINGLE_REPEATS)
31+
do_performance_test("input_array[0:10]", ARRAY_SIZES_LARGE, SINGLE_REPEATS, ARRAY_SETUP_CODES)
2632

2733

2834
def perftest_set_attr():
@@ -31,77 +37,77 @@ def perftest_set_attr():
3137
"rec": "input_array.id = 1",
3238
"fancy": "input_array.id = 1",
3339
}
34-
do_performance_test(code_to_test, ARRAY_SIZES_LARGE, SINGLE_REPEATS)
40+
do_performance_test(code_to_test, ARRAY_SIZES_LARGE, SINGLE_REPEATS, ARRAY_SETUP_CODES)
3541

3642

3743
def perftest_set_field():
38-
do_performance_test("input_array['id'] = 1", ARRAY_SIZES_LARGE, SINGLE_REPEATS)
44+
do_performance_test("input_array['id'] = 1", ARRAY_SIZES_LARGE, SINGLE_REPEATS, ARRAY_SETUP_CODES)
3945

4046

4147
def perftest_loop_slice_1():
42-
code_to_test = "for i in range({array_size}): input_array[i]"
43-
do_performance_test(code_to_test, ARRAY_SIZES_SMALL, LOOP_REPEATS)
48+
code_to_test = "for i in range({size}): input_array[i]"
49+
do_performance_test(code_to_test, ARRAY_SIZES_SMALL, LOOP_REPEATS, ARRAY_SETUP_CODES)
4450

4551

4652
def perftest_loop_data_slice_1():
4753
code_to_test = {
48-
"structured": "for i in range({array_size}): input_array[i]",
49-
"rec": "for i in range({array_size}): input_array[i]",
50-
"fancy": "for i in range({array_size}): input_array.data[i]",
54+
"structured": "for i in range({size}): input_array[i]",
55+
"rec": "for i in range({size}): input_array[i]",
56+
"fancy": "for i in range({size}): input_array.data[i]",
5157
}
52-
do_performance_test(code_to_test, ARRAY_SIZES_SMALL, LOOP_REPEATS)
58+
do_performance_test(code_to_test, ARRAY_SIZES_SMALL, LOOP_REPEATS, ARRAY_SETUP_CODES)
5359

5460

5561
def perftest_loop_slice():
56-
code_to_test = "for i in range({array_size}): input_array[i:i+1]"
57-
do_performance_test(code_to_test, ARRAY_SIZES_SMALL, LOOP_REPEATS)
62+
code_to_test = "for i in range({size}): input_array[i:i+1]"
63+
do_performance_test(code_to_test, ARRAY_SIZES_SMALL, LOOP_REPEATS, ARRAY_SETUP_CODES)
5864

5965

6066
def perftest_loop_set_field():
61-
code_to_test = "for i in range({array_size}): input_array['id'][i] = 1"
62-
do_performance_test(code_to_test, ARRAY_SIZES_SMALL, LOOP_REPEATS)
67+
code_to_test = "for i in range({size}): input_array['id'][i] = 1"
68+
do_performance_test(code_to_test, ARRAY_SIZES_SMALL, LOOP_REPEATS, ARRAY_SETUP_CODES)
6369

6470

6571
def perftest_loop_get_field():
6672
code_to_test = "for row in input_array: row['id']"
67-
do_performance_test(code_to_test, ARRAY_SIZES_SMALL, LOOP_REPEATS)
73+
do_performance_test(code_to_test, ARRAY_SIZES_SMALL, LOOP_REPEATS, ARRAY_SETUP_CODES)
6874

6975

7076
def perftest_loop_data_get_field():
7177
code_to_test = "for row in input_array.data: row['id']"
72-
do_performance_test(code_to_test, ARRAY_SIZES_SMALL, LOOP_REPEATS)
78+
do_performance_test(code_to_test, ARRAY_SIZES_SMALL, LOOP_REPEATS, ARRAY_SETUP_CODES)
7379

7480

7581
def perftest_loop_get_attr():
7682
code_to_test = "for row in input_array: row.id"
77-
do_performance_test(code_to_test, ARRAY_SIZES_SMALL, 100)
83+
do_performance_test(code_to_test, ARRAY_SIZES_SMALL, 100, ARRAY_SETUP_CODES)
7884

7985

8086
def perftest_fancypy_concat():
8187
code_to_test = {
8288
"structured": "import numpy as np;np.concatenate([input_array, input_array])",
8389
"rec": "import numpy as np;np.concatenate([input_array, input_array])",
84-
"fancy": "import power_grid_model_ds._core.fancypy as fp;fp.concatenate(input_array, input_array)",
90+
"fancy": "import power_grid_model_ds.fancypy as fp;fp.concatenate(input_array, input_array)",
8591
}
86-
do_performance_test(code_to_test, ARRAY_SIZES_LARGE, 100)
92+
do_performance_test(code_to_test, ARRAY_SIZES_LARGE, 100, ARRAY_SETUP_CODES)
8793

8894

8995
def perftest_fancypy_unique():
9096
code_to_test = {
9197
"structured": "import numpy as np;np.unique(input_array)",
9298
"rec": "import numpy as np;np.unique(input_array)",
93-
"fancy": "import power_grid_model_ds._core.fancypy as fp;fp.unique(input_array)",
99+
"fancy": "import power_grid_model_ds.fancypy as fp;fp.unique(input_array)",
94100
}
95-
do_performance_test(code_to_test, ARRAY_SIZES_SMALL, 100)
101+
do_performance_test(code_to_test, ARRAY_SIZES_SMALL, 100, ARRAY_SETUP_CODES)
96102

97103

98104
def perftest_fancypy_sort():
99105
code_to_test = {
100106
"structured": "import numpy as np;np.sort(input_array)",
101107
"rec": "import numpy as np;np.sort(input_array)",
102-
"fancy": "import power_grid_model_ds._core.fancypy as fp;fp.sort(input_array)",
108+
"fancy": "import power_grid_model_ds.fancypy as fp;fp.sort(input_array)",
103109
}
104-
do_performance_test(code_to_test, ARRAY_SIZES_SMALL, 100)
110+
do_performance_test(code_to_test, ARRAY_SIZES_SMALL, 100, ARRAY_SETUP_CODES)
105111

106112

107113
if __name__ == "__main__":

0 commit comments

Comments
 (0)