@@ -991,16 +991,19 @@ def check(dtype, N, M_=None, k=0):
991991 if M is None and config .mode in ["DebugMode" , "DEBUG_MODE" ]:
992992 M = N
993993 N_symb = iscalar ()
994- M_symb = iscalar ()
995- k_symb = iscalar ()
996- f = function (
997- [N_symb , M_symb , k_symb ], tri (N_symb , M_symb , k_symb , dtype = dtype )
998- )
999- result = f (N , M , k )
994+ f = function ([N_symb ], tri (N_symb , M = M , k = k , dtype = dtype ))
995+ # kwargs = {}
996+ result = f (N )
1000997 assert np .allclose (result , np .tri (N , M_ , k , dtype = dtype ))
1001998 assert result .dtype == np .dtype (dtype )
1002999
1003- for dtype in ["int32" , "int64" , "float32" , "float64" , "uint16" , "complex64" ]:
1000+ for dtype in [
1001+ "int32" ,
1002+ "int64" ,
1003+ "float32" ,
1004+ "float64" ,
1005+ "uint16" ,
1006+ ]: # Handle "complex64" ?
10041007 check (dtype , 3 )
10051008 # M != N, k = 0
10061009 check (dtype , 3 , 5 )
@@ -1022,15 +1025,15 @@ def test_tril_triu(self):
10221025
10231026 def check_l (m , k = 0 ):
10241027 m_symb = matrix (dtype = m .dtype )
1025- k_symb = iscalar ()
1026- f = function ([m_symb , k_symb ], tril (m_symb , k_symb ))
1028+ # k_symb = iscalar()
1029+ f = function ([m_symb ], tril (m_symb , k = k ))
10271030 f_indx = function (
1028- [m_symb , k_symb ], tril_indices (m_symb .shape [0 ], k_symb , m_symb .shape [1 ])
1031+ [m_symb ], tril_indices (m_symb .shape [0 ], k = k , m = m_symb .shape [1 ])
10291032 )
1030- f_indx_from = function ([m_symb , k_symb ], tril_indices_from (m_symb , k_symb ))
1031- result = f (m , k )
1032- result_indx = f_indx (m , k )
1033- result_from = f_indx_from (m , k )
1033+ f_indx_from = function ([m_symb ], tril_indices_from (m_symb ))
1034+ result = f (m )
1035+ result_indx = f_indx (m , k = k )
1036+ result_from = f_indx_from (m , k = k )
10341037 assert np .allclose (result , np .tril (m , k ))
10351038 assert np .allclose (result_indx , np .tril_indices (m .shape [0 ], k , m .shape [1 ]))
10361039 assert np .allclose (result_from , np .tril_indices_from (m , k ))
@@ -1040,7 +1043,7 @@ def check_l(m, k=0):
10401043 def check_u (m , k = 0 ):
10411044 m_symb = matrix (dtype = m .dtype )
10421045 k_symb = iscalar ()
1043- f = function ([m_symb , k_symb ], triu (m_symb , k_symb ))
1046+ f = function ([m_symb , k_symb ], triu (m_symb , k = k ))
10441047 f_indx = function (
10451048 [m_symb , k_symb ], triu_indices (m_symb .shape [0 ], k_symb , m_symb .shape [1 ])
10461049 )
0 commit comments