Skip to content

Commit 41eec52

Browse files
authored
Merge pull request #6 from ekomarova/transition-to-scikit-build
Replace test_basic with pytest
2 parents 87005c1 + 2d24480 commit 41eec52

File tree

2 files changed

+21
-20
lines changed

2 files changed

+21
-20
lines changed

conda-recipe/meta.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,12 @@ requirements:
3737
- {{ pin_compatible('numpy') }}
3838

3939
test:
40+
requires:
41+
- pytest
4042
source_files:
4143
- mkl_umath/tests/test_basic.py
4244
commands:
43-
- python mkl_umath/tests/test_basic.py
45+
- pytest mkl_umath/tests/test_basic.py
4446
imports:
4547
- mkl_umath
4648
- mkl_umath._ufuncs

mkl_umath/tests/test_basic.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
2424
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2525

26+
import pytest
2627
import numpy as np
2728
import mkl_umath._ufuncs as mu
2829
import numpy.core.umath as nu
@@ -49,11 +50,8 @@ def get_args(args_str):
4950
return tuple(args)
5051

5152
umaths = [i for i in dir(mu) if isinstance(getattr(mu, i), np.ufunc)]
52-
5353
umaths.remove('arccosh') # expects input greater than 1
5454

55-
# dictionary with test cases
56-
# (umath, types) : args
5755
generated_cases = {}
5856
for umath in umaths:
5957
mkl_umath = getattr(mu, umath)
@@ -64,29 +62,30 @@ def get_args(args_str):
6462
generated_cases[(umath, type)] = args
6563

6664
additional_cases = {
67-
('arccosh', 'f->f') : (np.single(np.random.random_sample() + 1),),
68-
('arccosh', 'd->d') : (np.double(np.random.random_sample() + 1),),
65+
('arccosh', 'f->f'): (np.single(np.random.random_sample() + 1),),
66+
('arccosh', 'd->d'): (np.double(np.random.random_sample() + 1),),
6967
}
7068

71-
test_cases = {}
72-
for d in (generated_cases, additional_cases):
73-
test_cases.update(d)
69+
test_cases = {**generated_cases, **additional_cases}
7470

75-
for case in test_cases:
76-
umath = case[0]
77-
type = case[1]
71+
@pytest.mark.parametrize("case", list(test_cases.keys()))
72+
def test_umath(case):
73+
umath, type = case
7874
args = test_cases[case]
7975
mkl_umath = getattr(mu, umath)
8076
np_umath = getattr(nu, umath)
8177
print('*'*80)
82-
print(umath, type)
83-
print("args", args)
78+
print(f"Testing {umath} with type {type}")
79+
print("args:", args)
80+
8481
mkl_res = mkl_umath(*args)
8582
np_res = np_umath(*args)
86-
print("mkl res", mkl_res)
87-
print("npy res", np_res)
88-
89-
assert np.allclose(mkl_res, np_res)
83+
84+
print("mkl res:", mkl_res)
85+
print("npy res:", np_res)
86+
87+
assert np.allclose(mkl_res, np_res), f"Results for {umath} do not match"
9088

91-
print("Test cases count:", len(test_cases))
92-
print("All looks good!")
89+
def test_cases_count():
90+
print("Test cases count:", len(test_cases))
91+
assert len(test_cases) > 0, "No test cases found"

0 commit comments

Comments
 (0)