@@ -26,7 +26,7 @@ def test_maximum_minium(dtype, op, device):
2626
2727
2828@pytest .mark .interpreter
29- @pytest .mark .parametrize ("M, N" , [[1 , 512 ], [8 , 64 ], [256 , 16 ], [512 , 8 ]])
29+ @pytest .mark .parametrize ("M, N" , [[1 , 1 ], [ 1 , 512 ], [8 , 64 ], [256 , 16 ], [512 , 8 ]])
3030@pytest .mark .parametrize ("k" , [None , 8 ])
3131@pytest .mark .parametrize ("descending" , [False , True ])
3232@pytest .mark .parametrize ("dtype_str" , ['int32' , 'float16' , 'float32' , 'bfloat16' ])
@@ -40,7 +40,7 @@ def sort_kernel(X, stride_xm, Z, stride_zm, M: tl.constexpr, N: tl.constexpr, k:
4040 offs_z_n = offs_x_n if k is None else tl .arange (0 , k )
4141 offs_x = offs_m [:, None ] * stride_xm + offs_x_n [None , :]
4242 x = tl .load (X + offs_x )
43- if k is None :
43+ if k is None or x . numel < k :
4444 z = tl .sort (x , descending = descending )
4545 else :
4646 z = tl .topk (x , k )
@@ -51,7 +51,7 @@ def sort_kernel(X, stride_xm, Z, stride_zm, M: tl.constexpr, N: tl.constexpr, k:
5151 x = numpy_random ((M , N ), dtype_str = dtype_str )
5252 x = torch .from_numpy (x ).to (device )
5353 z = torch .empty (z_shape , dtype = x .dtype , device = x .device )
54- if k is None :
54+ if k is None or x . numel () < k :
5555 y = torch .sort (x , descending = descending )[0 ]
5656 else :
5757 y = torch .topk (x , k = k ).values
0 commit comments