@@ -5581,7 +5581,7 @@ def matmul_kernel( #
         stride_cm, stride_cn,  #
         BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,  #
         low_precision_acc: tl.constexpr,  #
-        num_pipeline_stages: tl.constexpr = 3  #
+        num_stages: tl.constexpr = 3  #
 ):
     pid = tl.program_id(axis=0)
     num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
@@ -5593,7 +5593,7 @@ def matmul_kernel( #
     a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
     b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-    for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K), num_stages=num_pipeline_stages):
+    for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K), num_stages=num_stages):
         a = tl.load(a_ptrs)
         b = tl.load(b_ptrs)
         accumulator = tl.dot(a, b, acc=accumulator, max_num_imprecise_acc=low_precision_acc)
@@ -5632,7 +5632,7 @@ def test_dot_max_num_imprecise_acc(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, in_type_s
     max_num_impressive_acc = low_precision_acc if low_precision_acc <= BLOCK_K else None
     h = matmul_kernel[grid](a, b, C, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), C.stride(0),
                             C.stride(1), BLOCK_M, BLOCK_N, BLOCK_K, max_num_impressive_acc, num_warps=num_warps,
-                            num_pipeline_stages=num_stages)
+                            num_stages=num_stages)
     torch_a = torch.from_numpy(A).to(device=device)
     th_a = f8_to_f16(torch_a, in_type_str)
     torch_b = torch.from_numpy(B).to(device=device)
@@ -5824,7 +5824,7 @@ def test_tl_range(device):
     pgm = matmul_kernel[
         1,
     ](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), BLOCK_M, BLOCK_N,
-      BLOCK_K, 0, num_pipeline_stages=5)
+      BLOCK_K, 0, num_stages=5)
     ref_out = torch.matmul(a, b).to(torch.float32)
     if is_interpreter():
         # GPU invokes tensor core for float16 matmul, which is not supported in interpreter.
@@ -5850,8 +5850,8 @@ def maxnreg_noinline2(X):
     tl.store(X, 0)
 
 
+@pytest.mark.interpreter
 def test_maxnreg(device):
-    assert not is_interpreter(), "this test won't work with the interpreter"
     if not is_cuda():
         pytest.skip('maxnreg only works on CUDA')
 
@@ -5865,14 +5865,15 @@ def kernel(X):
     X = torch.empty(1, dtype=torch.int32, device=device)
     k = kernel[(1, )](X, maxnreg=42)
 
-    # Ensure that .maxnreg is set on the kernel function (marked with .entry)
-    # and not on either of the noinline functions (marked with .func).
-    try:
-        assert re.search(r'\.visible \.entry [^{;]*\.maxnreg 42', k.asm["ptx"])
-        assert not re.search(r'\.visible \.func [^{;]*\.maxnreg', k.asm["ptx"])
-    except AssertionError:
-        print("Failing ptx:\n", k.asm["ptx"])
-        raise
+    if not is_interpreter():
+        # Ensure that .maxnreg is set on the kernel function (marked with .entry)
+        # and not on either of the noinline functions (marked with .func).
+        try:
+            assert re.search(r'\.visible \.entry [^{;]*\.maxnreg 42', k.asm["ptx"])
+            assert not re.search(r'\.visible \.func [^{;]*\.maxnreg', k.asm["ptx"])
+        except AssertionError:
+            print("Failing ptx:\n", k.asm["ptx"])
+            raise
 
 
 @pytest.mark.interpreter
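
For context, a minimal standalone sketch (not part of the diff) of how the renamed per-loop hint is used by caller code: tl.range takes num_stages as a keyword to request software pipelining of that loop, matching the loop rewritten above. The kernel name sum_kernel, the BLOCK size, and the tolerance are illustrative assumptions, and the snippet assumes a CUDA device.

import torch
import triton
import triton.language as tl


@triton.jit
def sum_kernel(x_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    acc = tl.zeros((BLOCK, ), dtype=tl.float32)
    # num_stages is the pipelining hint exercised by the diff above; the value 3 is arbitrary here.
    for start in tl.range(0, n_elements, BLOCK, num_stages=3):
        offs = start + tl.arange(0, BLOCK)
        acc += tl.load(x_ptr + offs, mask=offs < n_elements, other=0.0)
    tl.store(out_ptr, tl.sum(acc, axis=0))


x = torch.randn(4096, device='cuda')
out = torch.empty(1, device='cuda')
sum_kernel[(1, )](x, out, x.numel(), BLOCK=256)
assert torch.allclose(out[0], x.sum(), atol=1e-2)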