@@ -5630,7 +5630,7 @@ def matmul_kernel(  #
         stride_cm, stride_cn,  #
         BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,  #
         low_precision_acc: tl.constexpr,  #
-        num_pipeline_stages: tl.constexpr = 3  #
+        num_stages: tl.constexpr = 3  #
 ):
     pid = tl.program_id(axis=0)
     num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
@@ -5642,7 +5642,7 @@ def matmul_kernel(  #
     a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
     b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-    for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K), num_stages=num_pipeline_stages):
+    for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K), num_stages=num_stages):
         a = tl.load(a_ptrs)
         b = tl.load(b_ptrs)
         accumulator = tl.dot(a, b, acc=accumulator, max_num_imprecise_acc=low_precision_acc)
@@ -5681,7 +5681,7 @@ def test_dot_max_num_imprecise_acc(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, in_type_s
     max_num_impressive_acc = low_precision_acc if low_precision_acc <= BLOCK_K else None
     h = matmul_kernel[grid](a, b, C, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), C.stride(0),
                             C.stride(1), BLOCK_M, BLOCK_N, BLOCK_K, max_num_impressive_acc, num_warps=num_warps,
-                            num_pipeline_stages=num_stages)
+                            num_stages=num_stages)
     torch_a = torch.from_numpy(A).to(device=device)
     th_a = f8_to_f16(torch_a, in_type_str)
     torch_b = torch.from_numpy(B).to(device=device)
@@ -5873,7 +5873,7 @@ def test_tl_range(device):
     pgm = matmul_kernel[
         1,
     ](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), BLOCK_M, BLOCK_N,
-      BLOCK_K, 0, num_pipeline_stages=5)
+      BLOCK_K, 0, num_stages=5)
     ref_out = torch.matmul(a, b).to(torch.float32)
     if is_interpreter():
         # GPU invokes tensor core for float16 matmul, which is not supported in interpreter.
@@ -5899,8 +5899,8 @@ def maxnreg_noinline2(X):
     tl.store(X, 0)


+@pytest.mark.interpreter
 def test_maxnreg(device):
-    assert not is_interpreter(), "this test won't work with the interpreter"
     if not is_cuda():
         pytest.xfail('maxnreg only works on CUDA')

@@ -5914,14 +5914,15 @@ def kernel(X):
     X = torch.empty(1, dtype=torch.int32, device=device)
     k = kernel[(1, )](X, maxnreg=42)

-    # Ensure that .maxnreg is set on the kernel function (marked with .entry)
-    # and not on either of the noinline functions (marked with .func).
-    try:
-        assert re.search(r'\.visible \.entry [^{;]*\.maxnreg 42', k.asm["ptx"])
-        assert not re.search(r'\.visible \.func [^{;]*\.maxnreg', k.asm["ptx"])
-    except AssertionError:
-        print("Failing ptx:\n", k.asm["ptx"])
-        raise
+    if not is_interpreter():
+        # Ensure that .maxnreg is set on the kernel function (marked with .entry)
+        # and not on either of the noinline functions (marked with .func).
+        try:
+            assert re.search(r'\.visible \.entry [^{;]*\.maxnreg 42', k.asm["ptx"])
+            assert not re.search(r'\.visible \.func [^{;]*\.maxnreg', k.asm["ptx"])
+        except AssertionError:
+            print("Failing ptx:\n", k.asm["ptx"])
+            raise


 @pytest.mark.interpreter