Skip to content

Commit 0c927ce

Browse files
chore: add typing to tests
Signed-off-by: jaapschoutenalliander <[email protected]>
1 parent dbe1714 commit 0c927ce

File tree

20 files changed

+230
-205
lines changed

20 files changed

+230
-205
lines changed

src/power_grid_model_ds/_core/model/containers/grid_protocol.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# SPDX-License-Identifier: MPL-2.0
44
"""This file contains the Grid protocol defining the minimal arrays contained in a grid"""
55

6+
from abc import abstractmethod
67
from typing import Protocol
78

89
from power_grid_model_ds._core.model.arrays import (
@@ -18,5 +19,13 @@ class MinimalGridArrays(Protocol):
1819

1920
node: NodeArray
2021
three_winding_transformer: ThreeWindingTransformerArray
21-
branches: BranchArray
22-
branch_arrays: list[BranchArray]
22+
23+
@property
24+
@abstractmethod
25+
def branches(self) -> BranchArray:
26+
"""Converts all branch arrays into a single BranchArray."""
27+
28+
@property
29+
@abstractmethod
30+
def branch_arrays(self) -> list[BranchArray]:
31+
"""Returns all branch arrays"""

src/power_grid_model_ds/_core/model/grids/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from power_grid_model_ds._core.model.arrays.base.array import FancyArray
3636
from power_grid_model_ds._core.model.arrays.base.errors import RecordDoesNotExist
3737
from power_grid_model_ds._core.model.containers.base import FancyArrayContainer
38+
from power_grid_model_ds._core.model.containers.grid_protocol import MinimalGridArrays
3839
from power_grid_model_ds._core.model.enums.nodes import NodeType
3940
from power_grid_model_ds._core.model.graphs.container import GraphContainer
4041
from power_grid_model_ds._core.model.graphs.models import RustworkxGraphModel
@@ -51,7 +52,7 @@
5152

5253

5354
@dataclass
54-
class Grid(FancyArrayContainer):
55+
class Grid(FancyArrayContainer, MinimalGridArrays):
5556
"""Grid object containing the entire network and interface to interact with it.
5657
5758
Examples:

tests/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,12 @@ def fancy_test_array():
6464

6565

6666
@pytest.fixture
67-
def basic_grid(grid):
67+
def basic_grid(grid: Grid):
6868
yield build_basic_grid(grid)
6969

7070

7171
@pytest.fixture
72-
def grid_with_3wt(grid):
72+
def grid_with_3wt(grid: Grid):
7373
yield build_basic_grid_with_three_winding(grid)
7474

7575

tests/integration/loadflow/test_power_grid_model.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def test_load_flow_on_random():
6060
assert all(output["line"]["i_from"] > 0)
6161

6262

63-
def test_load_flow(grid):
63+
def test_load_flow(grid: Grid):
6464
"""Tests the load flow on a test grid with 2 nodes"""
6565
nodes = NodeArray.zeros(2)
6666
nodes.id = [0, 1]
@@ -105,7 +105,7 @@ def test_load_flow(grid):
105105
assert all(output["line"]["i_from"] > 0)
106106

107107

108-
def test_load_flow_with_transformer(grid):
108+
def test_load_flow_with_transformer(grid: Grid):
109109
"""Tests the load flow on a test grid with 3 nodes and a trafo"""
110110
nodes = NodeArray.zeros(3)
111111
nodes.id = [0, 1, 2]
@@ -178,7 +178,7 @@ def test_load_flow_with_transformer(grid):
178178

179179
# pylint: disable=too-many-statements
180180
# pylint: disable=duplicate-code
181-
def test_load_flow_with_three_winding_transformer(grid):
181+
def test_load_flow_with_three_winding_transformer(grid: Grid):
182182
"""Tests the load flow on a test grid with 3 nodes and a three winding trafo"""
183183
nodes = NodeArray.zeros(3)
184184
nodes.id = [0, 1, 2]
@@ -249,7 +249,7 @@ def test_load_flow_with_three_winding_transformer(grid):
249249
assert all(output["three_winding_transformer"]["loading"] > 0)
250250

251251

252-
def test_load_flow_with_link(grid):
252+
def test_load_flow_with_link(grid: Grid):
253253
"""Tests the load flow on a test grid with 2 nodes and a link"""
254254
nodes = NodeArray.zeros(2)
255255
nodes.id = [0, 1]
@@ -293,7 +293,7 @@ def test_load_flow_with_link(grid):
293293
assert all(output["link"]["i_from"] > 0)
294294

295295

296-
def test_automatic_tap_regulator(grid):
296+
def test_automatic_tap_regulator(grid: Grid):
297297
"""Test automatic tap regulator
298298
299299
Network:

tests/unit/data_source/generator/test_grid_generators.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from power_grid_model_ds._core.data_source.generator.grid_generators import RadialGridGenerator
1515
from power_grid_model_ds._core.load_flow import PowerGridModelInterface
1616
from power_grid_model_ds._core.model.arrays import LineArray, NodeArray, SourceArray, SymLoadArray
17+
from power_grid_model_ds._core.model.graphs.models.base import BaseGraphModel
1718
from power_grid_model_ds._core.model.grids.base import Grid
1819

1920

@@ -37,14 +38,14 @@ def test_generate_random_grid():
3738
assert len(grid.line) == grid.graphs.complete_graph.nr_branches
3839

3940

40-
def test_graph_generate_random_grid_with_different_graph_engines(graph):
41+
def test_graph_generate_random_grid_with_different_graph_engines(graph: BaseGraphModel):
4142
"""Generate a random grid with correct structure"""
4243
grid_generator = RadialGridGenerator(grid_class=Grid, graph_model=graph.__class__)
4344
grid = grid_generator.run(seed=0)
4445
assert isinstance(grid.graphs.active_graph, graph.__class__)
4546

4647

47-
def test_generate_random_nodes(grid):
48+
def test_generate_random_nodes(grid: Grid):
4849
"""Generate random nodes"""
4950
node_generator = NodeGenerator(grid, seed=0)
5051
nodes, loads_low, loads_high = node_generator.run(amount=2)
@@ -64,7 +65,7 @@ def test_generate_random_nodes(grid):
6465
assert all(np.isin(loads_low.node, nodes.id))
6566

6667

67-
def test_generate_random_sources(grid):
68+
def test_generate_random_sources(grid: Grid):
6869
"""Generate random sources"""
6970
source_generator = SourceGenerator(grid=grid, seed=0)
7071
nodes, sources = source_generator.run(amount=1)
@@ -81,7 +82,7 @@ def test_generate_random_sources(grid):
8182
assert all(np.isin(sources.node, nodes.id))
8283

8384

84-
def test_generate_random_lines(grid):
85+
def test_generate_random_lines(grid: Grid):
8586
"""Generate random lines"""
8687
nodes = NodeArray.zeros(4)
8788
nodes.id = [0, 1, 2, 3]
@@ -112,7 +113,7 @@ def test_generate_random_lines(grid):
112113
assert all(np.isin(lines.to_node, nodes.id))
113114

114115

115-
def test_create_routes(grid):
116+
def test_create_routes(grid: Grid):
116117
"""Generate new routes"""
117118
nodes = NodeArray.zeros(4)
118119
nodes.id = [0, 1, 2, 3]
@@ -146,7 +147,7 @@ def test_create_routes(grid):
146147
assert all(np.isin(line_generator.line_array.to_node, nodes.id))
147148

148149

149-
def test_determine_number_of_routes(grid):
150+
def test_determine_number_of_routes(grid: Grid):
150151
"""Number of routes"""
151152
line_generator = LineGenerator(grid=grid, seed=0)
152153

@@ -168,7 +169,7 @@ def test_determine_number_of_routes(grid):
168169
assert 3 == number_of_routes
169170

170171

171-
def test_connect_nodes(grid):
172+
def test_connect_nodes(grid: Grid):
172173
"""Connect nodes"""
173174
nodes = NodeArray.zeros(4)
174175
nodes.id = [0, 1, 2, 3]
@@ -205,7 +206,7 @@ def test_connect_nodes(grid):
205206
assert 2 == len(line_generator.line_array)
206207

207208

208-
def test_create_nops(grid):
209+
def test_create_nops(grid: Grid):
209210
"""Create normally open points"""
210211
nodes = NodeArray.zeros(4)
211212
nodes.id = [0, 1, 2, 3]

tests/unit/model/arrays/test_array.py

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -33,63 +33,63 @@ class _InheritedStrLengthArray(_CustomStrLengthArray):
3333
_str_lengths = {"extra_string": 100}
3434

3535

36-
def test_get_non_existing_attribute(fancy_test_array):
36+
def test_get_non_existing_attribute(fancy_test_array: FancyTestArray):
3737
with pytest.raises(AttributeError):
3838
# pylint: disable=pointless-statement
3939
fancy_test_array.non_existing_attribute # noqa
4040

4141

42-
def test_for_loop(fancy_test_array):
42+
def test_for_loop(fancy_test_array: FancyTestArray):
4343
for row in fancy_test_array:
4444
assert row in fancy_test_array
4545
assert isinstance(row, FancyTestArray)
4646

4747

48-
def test_setattr(fancy_test_array):
48+
def test_setattr(fancy_test_array: FancyTestArray):
4949
assert_array_equal(fancy_test_array.id, [1, 2, 3])
5050
fancy_test_array.id = [9, 9, 9]
5151

5252
assert_array_equal(fancy_test_array.id, [9, 9, 9])
5353
assert_array_equal(fancy_test_array.data["id"], [9, 9, 9])
5454

5555

56-
def test_prevent_delete_numpy_attribute(fancy_test_array):
56+
def test_prevent_delete_numpy_attribute(fancy_test_array: FancyTestArray):
5757
with pytest.raises(AttributeError):
5858
del fancy_test_array.size # 'size' is a numpy attribute
5959

6060

61-
def test_getitem_array_one_column(fancy_test_array):
61+
def test_getitem_array_one_column(fancy_test_array: FancyTestArray):
6262
assert_array_equal(fancy_test_array["id"], [1, 2, 3])
6363

6464

65-
def test_getitem_array_multiple_columns(fancy_test_array):
65+
def test_getitem_array_multiple_columns(fancy_test_array: FancyTestArray):
6666
columns = ["id", "test_int", "test_float"]
6767
assert fancy_test_array.data[columns].tolist() == fancy_test_array[columns].tolist()
6868
assert_array_equal(fancy_test_array[columns].dtype.names, ("id", "test_int", "test_float"))
6969

7070

71-
def test_getitem_unique_multiple_columns(fancy_test_array):
71+
def test_getitem_unique_multiple_columns(fancy_test_array: FancyTestArray):
7272
columns = ["id", "test_int", "test_float"]
7373
assert np.array_equal(np.unique(fancy_test_array[columns]), fancy_test_array[columns])
7474

7575

76-
def test_getitem_array_slice(fancy_test_array):
76+
def test_getitem_array_slice(fancy_test_array: FancyTestArray):
7777
assert fancy_test_array.data[0:2].tolist() == fancy_test_array[0:2].tolist()
7878

7979

80-
def test_getitem_with_array_mask(fancy_test_array):
80+
def test_getitem_with_array_mask(fancy_test_array: FancyTestArray):
8181
mask = np.array([True, False, True])
8282
assert isinstance(fancy_test_array[mask], FancyArray)
8383
assert np.array_equal(fancy_test_array.data[mask], fancy_test_array[mask].data)
8484

8585

86-
def test_getitem_with_tuple_mask(fancy_test_array):
86+
def test_getitem_with_tuple_mask(fancy_test_array: FancyTestArray):
8787
mask = (True, False, True)
8888
assert isinstance(fancy_test_array[mask], FancyArray)
8989
assert np.array_equal(fancy_test_array.data[mask], fancy_test_array[mask].data)
9090

9191

92-
def test_getitem_with_list_mask(fancy_test_array):
92+
def test_getitem_with_list_mask(fancy_test_array: FancyTestArray):
9393
mask = [True, False, True]
9494
assert isinstance(fancy_test_array[mask], FancyArray)
9595
assert np.array_equal(fancy_test_array.data[mask], fancy_test_array[mask].data)
@@ -102,54 +102,54 @@ def test_getitem_with_empty_list_mask():
102102
assert np.array_equal(array.data[mask], array[mask].data)
103103

104104

105-
def test_setitem_with_index(fancy_test_array):
105+
def test_setitem_with_index(fancy_test_array: FancyTestArray):
106106
fancy_test_array[0] = (9, 9, 9, 9, 9)
107107
assert [9, 2, 3] == fancy_test_array.id.tolist()
108108

109109

110-
def test_setitem_with_mask(fancy_test_array):
110+
def test_setitem_with_mask(fancy_test_array: FancyTestArray):
111111
mask = np.array([True, False, True])
112112
fancy_test_array[mask] = (9, 9, 9, 9, 9)
113113
assert [9, 2, 9] == fancy_test_array.id.tolist()
114114

115115

116-
def test_setitem_as_fancy_array_with_mask(fancy_test_array):
116+
def test_setitem_as_fancy_array_with_mask(fancy_test_array: FancyTestArray):
117117
mask = np.array([True, False, True])
118118
fancy_test_array[mask] = FancyTestArray.zeros(2)
119119
assert_array_equal([EMPTY_ID, 2, EMPTY_ID], fancy_test_array.id)
120120

121121

122-
def test_setitem_as_fancy_array_with_mask_too_large(fancy_test_array):
122+
def test_setitem_as_fancy_array_with_mask_too_large(fancy_test_array: FancyTestArray):
123123
mask = np.array([True, False, True])
124124
with pytest.raises(ValueError):
125125
fancy_test_array[mask] = FancyTestArray.zeros(3)
126126

127127

128-
def test_set_non_existing_field(fancy_test_array):
128+
def test_set_non_existing_field(fancy_test_array: FancyTestArray):
129129
with pytest.raises(AttributeError):
130130
fancy_test_array.non_existing_field = 123
131131

132132

133-
def test_set_callable(fancy_test_array):
133+
def test_set_callable(fancy_test_array: FancyTestArray):
134134
with pytest.raises(AttributeError):
135135
fancy_test_array.filter = 123
136136

137137

138-
def test_contains(fancy_test_array):
138+
def test_contains(fancy_test_array: FancyTestArray):
139139
assert fancy_test_array[0] in fancy_test_array
140140

141141

142-
def test_non_existing_method(fancy_test_array):
142+
def test_non_existing_method(fancy_test_array: FancyTestArray):
143143
with pytest.raises(AttributeError):
144144
# pylint: disable=no-member
145145
fancy_test_array.non_existing_method()
146146

147147

148-
def test_array_equal(fancy_test_array):
148+
def test_array_equal(fancy_test_array: FancyTestArray):
149149
assert fp.array_equal(fancy_test_array, fancy_test_array.copy())
150150

151151

152-
def test_array_not_equal(fancy_test_array):
152+
def test_array_not_equal(fancy_test_array: FancyTestArray):
153153
different_array = fancy_test_array.copy()
154154
different_array.test_int = 99
155155
assert not fp.array_equal(fancy_test_array, different_array)
@@ -203,60 +203,60 @@ def test_is_empty_bool():
203203
assert_array_equal(array.is_empty("test_bool"), [True, False])
204204

205205

206-
def test_unique(fancy_test_array):
206+
def test_unique(fancy_test_array: FancyTestArray):
207207
duplicate_array = fp.concatenate(fancy_test_array, fancy_test_array)
208208
unique_array = fp.unique(duplicate_array)
209209
assert fp.array_equal(unique_array, fancy_test_array)
210210

211211

212-
def test_unique_with_nan_values(fancy_test_array):
212+
def test_unique_with_nan_values(fancy_test_array: FancyTestArray):
213213
fancy_test_array.test_float = np.nan
214214
duplicate_array = fp.concatenate(fancy_test_array, fancy_test_array)
215215
with pytest.raises(NotImplementedError):
216216
fp.unique(duplicate_array)
217217

218218

219-
def test_unique_return_inverse(fancy_test_array):
219+
def test_unique_return_inverse(fancy_test_array: FancyTestArray):
220220
duplicate_array = fp.concatenate(fancy_test_array, fancy_test_array)
221221
unique_array, inverse = fp.unique(duplicate_array, return_inverse=True)
222222
assert fp.array_equal(unique_array, fancy_test_array)
223223
assert_array_equal(inverse, [0, 1, 2, 0, 1, 2])
224224

225225

226-
def test_unique_return_counts_and_inverse(fancy_test_array):
226+
def test_unique_return_counts_and_inverse(fancy_test_array: FancyTestArray):
227227
duplicate_array = fp.concatenate(fancy_test_array, fancy_test_array)
228228
unique_array, inverse, counts = fp.unique(duplicate_array, return_counts=True, return_inverse=True)
229229
assert fp.array_equal(unique_array, fancy_test_array)
230230
assert_array_equal(counts, [2, 2, 2])
231231
assert_array_equal(inverse, [0, 1, 2, 0, 1, 2])
232232

233233

234-
def test_sort(fancy_test_array):
234+
def test_sort(fancy_test_array: FancyTestArray):
235235
assert_array_equal(fancy_test_array.test_float, [4.0, 4.0, 1.0])
236236
fancy_test_array.sort(order="test_float")
237237
assert_array_equal(fancy_test_array.test_float, [1.0, 4.0, 4.0])
238238

239239

240-
def test_copy_function(fancy_test_array):
240+
def test_copy_function(fancy_test_array: FancyTestArray):
241241
array_copy = copy(fancy_test_array)
242242
array_copy.test_int = 123
243243
assert not id(fancy_test_array) == id(array_copy)
244244
assert not fancy_test_array.test_int[0] == array_copy.test_int[0]
245245

246246

247-
def test_copy_method(fancy_test_array):
247+
def test_copy_method(fancy_test_array: FancyTestArray):
248248
array_copy = fancy_test_array.copy()
249249
array_copy.test_int = 123
250250
assert not id(fancy_test_array.data) == id(array_copy.data)
251-
assert not fancy_test_array.test_int[0] == array_copy.test_int[0]
251+
assert not fancy_test_array.test_int[0] == array_copy.test_int[0] # type: ignore
252252

253253

254-
def test_prevent_np_unique_on_fancy_array(fancy_test_array):
254+
def test_prevent_np_unique_on_fancy_array(fancy_test_array: FancyTestArray):
255255
with pytest.raises(TypeError):
256256
np.unique(fancy_test_array)
257257

258258

259-
def test_prevent_np_sort_on_fancy_array(fancy_test_array):
259+
def test_prevent_np_sort_on_fancy_array(fancy_test_array: FancyTestArray):
260260
with pytest.raises(TypeError):
261261
np.sort(fancy_test_array)
262262

@@ -278,7 +278,7 @@ def test_string_inherit_string_length():
278278
assert_array_equal(array.extra_string, ["b" * 100])
279279

280280

281-
def test_shuffle_array(fancy_test_array):
281+
def test_shuffle_array(fancy_test_array: FancyTestArray):
282282
rng = np.random.default_rng(0)
283283
rng.shuffle(fancy_test_array.data)
284284
assert_array_equal(fancy_test_array.id, [3, 1, 2])

0 commit comments

Comments
 (0)