Skip to content

Commit bdb04da

Browse files
tests
1 parent 9ab852e commit bdb04da

File tree

1 file changed

+66
-2
lines changed

1 file changed

+66
-2
lines changed

mkl_umath/tests/test_basic.py

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,67 @@
1+
import numpy as np
2+
import mkl_umath._ufuncs as mu
3+
import numpy.core.umath as nu
14

2-
def test1():
3-
pass
5+
np.random.seed(42)
6+
7+
def get_args(args_str):
8+
args = []
9+
for s in args_str:
10+
if s == 'f':
11+
args.append(np.single(np.random.random_sample()))
12+
elif s == 'd':
13+
args.append(np.double(np.random.random_sample()))
14+
elif s == 'F':
15+
args.append(np.single(np.random.random_sample()) + np.single(np.random.random_sample()) * 1j)
16+
elif s == 'D':
17+
args.append(np.double(np.random.random_sample()) + np.double(np.random.random_sample()) * 1j)
18+
elif s == 'i':
19+
args.append(np.int(np.random.randint(low=1, high=10)))
20+
elif s == 'l':
21+
args.append(np.long(np.random.randint(low=1, high=10)))
22+
else:
23+
raise ValueError("Unexpected type specified!")
24+
return tuple(args)
25+
26+
umaths = [i for i in dir(mu) if isinstance(getattr(mu, i), np.ufunc)]
27+
28+
umaths.remove('arccosh') # expects input greater than 1
29+
30+
# dictionary with test cases
31+
# (umath, types) : args
32+
generated_cases = {}
33+
for umath in umaths:
34+
mkl_umath = getattr(mu, umath)
35+
types = mkl_umath.types
36+
for type in types:
37+
args_str = type[:type.find('->')]
38+
args = get_args(args_str)
39+
generated_cases[(umath, type)] = args
40+
41+
additional_cases = {
42+
('arccosh', 'f->f') : (np.single(np.random.random_sample() + 1),),
43+
('arccosh', 'd->d') : (np.double(np.random.random_sample() + 1),),
44+
}
45+
46+
test_cases = {}
47+
for d in (generated_cases, additional_cases):
48+
test_cases.update(d)
49+
50+
for case in test_cases:
51+
umath = case[0]
52+
type = case[1]
53+
args = test_cases[case]
54+
mkl_umath = getattr(mu, umath)
55+
np_umath = getattr(nu, umath)
56+
print('*'*80)
57+
print(umath, type)
58+
print("args", args)
59+
mkl_res = mkl_umath(*args)
60+
np_res = np_umath(*args)
61+
print("mkl res", mkl_res)
62+
print("npy res", np_res)
63+
64+
assert np.array_equal(mkl_res, np_res)
65+
66+
print("Test cases count:", len(test_cases))
67+
print("All looks good!")

0 commit comments

Comments
 (0)