Skip to content

Commit 3c07cbe

Browse files
Merge pull request #97 from alliander-opensource/bugfix/batch-validation-asym-loads
Bugfix: batch validation asym loads
2 parents d6c6d68 + ba2e3e4 commit 3c07cbe

File tree

3 files changed

+48
-4
lines changed

3 files changed

+48
-4
lines changed

src/power_grid_model/validation/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,8 @@ def update_input_data(input_data: Dict[str, np.ndarray], update_data: Dict[str,
215215
mask = ~np.isnan(array[field])
216216
else:
217217
mask = np.not_equal(array[field], nan)
218+
if mask.ndim == 2:
219+
mask = np.any(mask, axis=1)
218220
data = array[["id", field]][mask]
219221
idx = np.where(merged_data[component]["id"] == np.reshape(data["id"], (-1, 1)))
220222
if isinstance(idx, tuple):

tests/unit/validation/test_batch_validation.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66

77
import numpy as np
88
import pytest
9-
from power_grid_model import initialize_array
10-
from power_grid_model.validation import validate_batch_data
9+
10+
from power_grid_model import CalculationType, LoadGenType, initialize_array
11+
from power_grid_model.validation import validate_batch_data, validate_input_data
1112
from power_grid_model.validation.errors import MultiComponentNotUniqueError, NotBooleanError
1213

1314

@@ -28,15 +29,34 @@ def input_data() -> Dict[str, np.ndarray]:
2829
line["c1"] = 3.0
2930
line["tan1"] = 4.0
3031
line["i_n"] = 5.0
31-
return {"node": node, "line": line}
32+
33+
asym_load = initialize_array("input", "asym_load", 2)
34+
asym_load["id"] = [9, 10]
35+
asym_load["node"] = [1, 2]
36+
asym_load["status"] = [1, 1]
37+
asym_load["type"] = [LoadGenType.const_power, LoadGenType.const_power]
38+
asym_load["p_specified"] = [[11e6, 12e6, 13e6], [21e6, 22e6, 23e6]]
39+
asym_load["q_specified"] = [[11e5, 12e5, 13e5], [21e5, 22e5, 23e5]]
40+
41+
return {"node": node, "line": line, "asym_load": asym_load}
3242

3343

3444
@pytest.fixture
3545
def batch_data() -> Dict[str, np.ndarray]:
3646
line = initialize_array("update", "line", (3, 2))
3747
line["id"] = [[5, 6], [6, 7], [7, 5]]
3848
line["from_status"] = [[1, 1], [1, 1], [1, 1]]
39-
return {"line": line}
49+
50+
# Add batch for asym_load, which has 2-D array for p_specified
51+
asym_load = initialize_array("update", "asym_load", (3, 2))
52+
asym_load["id"] = [[9, 10], [9, 10], [9, 10]]
53+
54+
return {"line": line, "asym_load": asym_load}
55+
56+
57+
def test_validate_batch_data(input_data, batch_data):
58+
errors = validate_batch_data(input_data, batch_data)
59+
assert not errors
4060

4161

4262
def test_validate_batch_data_input_error(input_data, batch_data):

tests/unit/validation/test_utils.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,28 @@ def test_update_input_data_int_nan():
386386
np.testing.assert_array_equal(merged["line"]["from_status"], [0, -128, 1])
387387

388388

389+
def test_update_input_data_asym_nans():
390+
input_load = initialize_array("input", "asym_load", 3)
391+
input_load["id"] = [1, 2, 3]
392+
input_load["p_specified"] = [[1.1, 1.2, 1.3], [2.1, np.nan, np.nan], [np.nan, np.nan, np.nan]]
393+
394+
update_load = initialize_array("update", "asym_load", 3)
395+
update_load["id"] = [1, 2, 3]
396+
update_load["p_specified"] = [[np.nan, np.nan, np.nan], [np.nan, np.nan, 5.3], [6.1, 6.2, 6.3]]
397+
398+
merged = update_input_data(input_data={"asym_load": input_load}, update_data={"asym_load": update_load})
399+
400+
# The desired result would be to update all non-NaN values individually:
401+
# np.testing.assert_array_equal(
402+
# merged["asym_load"]["p_specified"], [[1.1, 1.2, 1.3], [2.1, np.nan, 5.3], [6.1, 6.2, 6.3]]
403+
# )
404+
405+
# The current C++ implementation updates the entire 3-phase value is one of the elements is non-NaN:
406+
np.testing.assert_array_equal(
407+
merged["asym_load"]["p_specified"], [[1.1, 1.2, 1.3], [np.nan, np.nan, 5.3], [6.1, 6.2, 6.3]]
408+
)
409+
410+
389411
def test_errors_to_string_no_errors():
390412
assert errors_to_string(errors=None) == "the data: OK"
391413
assert errors_to_string(errors=[]) == "the data: OK"

0 commit comments

Comments
 (0)