23
23
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
24
24
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25
25
26
+ import pytest
26
27
import numpy as np
27
28
import mkl_umath ._ufuncs as mu
28
29
import numpy .core .umath as nu
@@ -49,11 +50,8 @@ def get_args(args_str):
49
50
return tuple (args )
50
51
51
52
umaths = [i for i in dir (mu ) if isinstance (getattr (mu , i ), np .ufunc )]
52
-
53
53
umaths .remove ('arccosh' ) # expects input greater than 1
54
54
55
- # dictionary with test cases
56
- # (umath, types) : args
57
55
generated_cases = {}
58
56
for umath in umaths :
59
57
mkl_umath = getattr (mu , umath )
@@ -64,29 +62,30 @@ def get_args(args_str):
64
62
generated_cases [(umath , type )] = args
65
63
66
64
additional_cases = {
67
- ('arccosh' , 'f->f' ) : (np .single (np .random .random_sample () + 1 ),),
68
- ('arccosh' , 'd->d' ) : (np .double (np .random .random_sample () + 1 ),),
65
+ ('arccosh' , 'f->f' ): (np .single (np .random .random_sample () + 1 ),),
66
+ ('arccosh' , 'd->d' ): (np .double (np .random .random_sample () + 1 ),),
69
67
}
70
68
71
- test_cases = {}
72
- for d in (generated_cases , additional_cases ):
73
- test_cases .update (d )
69
+ test_cases = {** generated_cases , ** additional_cases }
74
70
75
- for case in test_cases :
76
- umath = case [ 0 ]
77
- type = case [ 1 ]
71
+ @ pytest . mark . parametrize ( " case" , list ( test_cases . keys ()))
72
+ def test_umath ( case ):
73
+ umath , type = case
78
74
args = test_cases [case ]
79
75
mkl_umath = getattr (mu , umath )
80
76
np_umath = getattr (nu , umath )
81
77
print ('*' * 80 )
82
- print (umath , type )
83
- print ("args" , args )
78
+ print (f"Testing { umath } with type { type } " )
79
+ print ("args:" , args )
80
+
84
81
mkl_res = mkl_umath (* args )
85
82
np_res = np_umath (* args )
86
- print ("mkl res" , mkl_res )
87
- print ("npy res" , np_res )
88
-
89
- assert np .allclose (mkl_res , np_res )
83
+
84
+ print ("mkl res:" , mkl_res )
85
+ print ("npy res:" , np_res )
86
+
87
+ assert np .allclose (mkl_res , np_res ), f"Results for { umath } do not match"
90
88
91
- print ("Test cases count:" , len (test_cases ))
92
- print ("All looks good!" )
89
+ def test_cases_count ():
90
+ print ("Test cases count:" , len (test_cases ))
91
+ assert len (test_cases ) > 0 , "No test cases found"
0 commit comments