Skip to content

Commit 978b152

Browse files
committed
refactor(test_parameters): removes duplication of ndarray checks via pytest fixture
1 parent 28db6dd commit 978b152

File tree

2 files changed

+21
-11
lines changed

2 files changed

+21
-11
lines changed
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import numpy as np
2+
import pytest
3+
4+
5+
@pytest.fixture
6+
def assert_is_ndarray():
7+
"""Recursively assert that all items in a structure are numpy arrays."""
8+
9+
def _assert(obj):
10+
if isinstance(obj, list | tuple):
11+
for item in obj:
12+
_assert(item)
13+
else:
14+
assert isinstance(obj, np.ndarray)
15+
16+
return _assert

tests/unit/test_parameters/test_process_parameter_data.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
from pathlib import Path
66

7-
import numpy as np
87
import pytest
98

109
import pybamm
@@ -34,18 +33,13 @@ def test_processed_name(self, parameter_data):
3433
name, processed = parameter_data
3534
assert processed[0] == name
3635

37-
def test_processed_structure(self, parameter_data):
38-
name, processed = parameter_data
39-
assert isinstance(processed[1], tuple)
40-
assert isinstance(processed[1][0][0], np.ndarray)
41-
assert isinstance(processed[1][1], np.ndarray)
36+
def test_processed_structure(self, parameter_data, assert_is_ndarray):
37+
_, processed = parameter_data
4238

43-
if len(processed[1][0]) > 1:
44-
assert isinstance(processed[1][0][1], np.ndarray)
39+
assert isinstance(processed[1], tuple)
4540

46-
elif len(processed[1]) == 3:
47-
assert isinstance(processed[1][0][1], np.ndarray)
48-
assert isinstance(processed[1][0][2], np.ndarray)
41+
# Recursively check that all numpy arrays exist where expected
42+
assert_is_ndarray(processed[1])
4943

5044
def test_error(self):
5145
with pytest.raises(FileNotFoundError, match="Could not find file"):

0 commit comments

Comments
 (0)