3535 key = ['M' , 'N' , 'K' ],
3636)
3737@triton .jit
38- def matmul_kernel_with_block_pointers (
38+ def matmul_kernel_with_tensor_descriptors (
3939 # Pointers to matrices
4040 a_ptr , b_ptr , c_ptr ,
4141 # Matrix dimensions
@@ -56,29 +56,26 @@ def matmul_kernel_with_block_pointers(
5656 pid_m = first_pid_m + ((pid % num_pid_in_group ) % group_size_m )
5757 pid_n = (pid % num_pid_in_group ) // group_size_m
5858
59- a_block_ptr = tl .make_block_ptr (base = a_ptr , shape = (M , K ), strides = (stride_am , stride_ak ),
60- offsets = (pid_m * BLOCK_SIZE_M , 0 ), block_shape = (BLOCK_SIZE_M , BLOCK_SIZE_K ),
61- order = (1 , 0 ))
62- b_block_ptr = tl .make_block_ptr (base = b_ptr , shape = (K , N ), strides = (stride_bk , stride_bn ),
63- offsets = (0 , pid_n * BLOCK_SIZE_N ), block_shape = (BLOCK_SIZE_K , BLOCK_SIZE_N ),
64- order = (1 , 0 ))
59+ a_desc = tl .make_tensor_descriptor (base = a_ptr , shape = (M , K ), strides = (stride_am , stride_ak ),
60+ block_shape = (BLOCK_SIZE_M , BLOCK_SIZE_K ))
61+ b_desc = tl .make_tensor_descriptor (base = b_ptr , shape = (K , N ), strides = (stride_bk , stride_bn ),
62+ block_shape = (BLOCK_SIZE_K , BLOCK_SIZE_N ))
6563
6664 accumulator = tl .zeros ((BLOCK_SIZE_M , BLOCK_SIZE_N ), dtype = tl .float32 )
65+ off_k = 0
6766 for _ in range (0 , K , BLOCK_SIZE_K ):
68- a = tl .load (a_block_ptr , boundary_check = ( 0 , 1 ) )
67+ a = a_desc .load ([ pid_m * BLOCK_SIZE_M , off_k ] )
6968 a = a .to (tl .float32 )
7069 a = tl .math .exp (a )
7170 a = a .to (tl .bfloat16 )
72- b = tl .load (b_block_ptr , boundary_check = ( 0 , 1 ) )
71+ b = b_desc .load ([ off_k , pid_n * BLOCK_SIZE_N ] )
7372 accumulator += tl .dot (a , b )
74- a_block_ptr = tl .advance (a_block_ptr , (0 , BLOCK_SIZE_K ))
75- b_block_ptr = tl .advance (b_block_ptr , (BLOCK_SIZE_K , 0 ))
73+ off_k += BLOCK_SIZE_K
7674 c = accumulator .to (tl .float32 )
7775
78- c_block_ptr = tl .make_block_ptr (base = c_ptr , shape = (M , N ), strides = (stride_cm , stride_cn ),
79- offsets = (pid_m * BLOCK_SIZE_M , pid_n * BLOCK_SIZE_N ),
80- block_shape = (BLOCK_SIZE_M , BLOCK_SIZE_N ), order = (1 , 0 ))
81- tl .store (c_block_ptr , c , boundary_check = (0 , 1 ))
76+ c_desc = tl .make_tensor_descriptor (base = c_ptr , shape = (M , N ), strides = (stride_cm , stride_cn ),
77+ block_shape = (BLOCK_SIZE_M , BLOCK_SIZE_N ))
78+ c_desc .store ([pid_m * BLOCK_SIZE_M , pid_n * BLOCK_SIZE_N ], c )
8279
8380
8481# pylint: disable=unused-argument
@@ -106,7 +103,7 @@ def matmul_kernel_with_block_pointers(
106103 key = ['M' , 'N' , 'K' ],
107104)
108105@triton .jit
109- def matmul_kernel_with_block_pointers_batched (
106+ def matmul_kernel_with_tensor_descriptors_batched (
110107 # Pointers to matrices
111108 a_ptr , b_ptr , c_ptr ,
112109 # Matrix dimensions
@@ -131,30 +128,27 @@ def matmul_kernel_with_block_pointers_batched(
131128 offset_a = bid .to (tl .int64 ) * stride_az
132129 offset_b = bid .to (tl .int64 ) * stride_bz
133130
134- a_block_ptr = tl .make_block_ptr (base = a_ptr + offset_a , shape = (M , K ), strides = (stride_am , stride_ak ),
135- offsets = (pid_m * BLOCK_SIZE_M , 0 ), block_shape = (BLOCK_SIZE_M , BLOCK_SIZE_K ),
136- order = (1 , 0 ))
137- b_block_ptr = tl .make_block_ptr (base = b_ptr + offset_b , shape = (K , N ), strides = (stride_bk , stride_bn ),
138- offsets = (0 , pid_n * BLOCK_SIZE_N ), block_shape = (BLOCK_SIZE_K , BLOCK_SIZE_N ),
139- order = (1 , 0 ))
131+ a_desc = tl .make_tensor_descriptor (base = a_ptr + offset_a , shape = (M , K ), strides = (stride_am , stride_ak ),
132+ block_shape = (BLOCK_SIZE_M , BLOCK_SIZE_K ))
133+ b_desc = tl .make_tensor_descriptor (base = b_ptr + offset_b , shape = (K , N ), strides = (stride_bk , stride_bn ),
134+ block_shape = (BLOCK_SIZE_K , BLOCK_SIZE_N ))
140135
141136 accumulator = tl .zeros ((BLOCK_SIZE_M , BLOCK_SIZE_N ), dtype = tl .float32 )
137+ off_k = 0
142138 for _ in range (0 , K , BLOCK_SIZE_K ):
143- a = tl .load (a_block_ptr , boundary_check = ( 0 , 1 ) )
139+ a = a_desc .load ([ pid_m * BLOCK_SIZE_M , off_k ] )
144140 a = a .to (tl .float32 )
145141 a = tl .math .exp (a )
146142 a = a .to (tl .bfloat16 )
147- b = tl .load (b_block_ptr , boundary_check = ( 0 , 1 ) )
143+ b = b_desc .load ([ off_k , pid_n * BLOCK_SIZE_N ] )
148144 accumulator += tl .dot (a , b )
149- a_block_ptr = tl .advance (a_block_ptr , (0 , BLOCK_SIZE_K ))
150- b_block_ptr = tl .advance (b_block_ptr , (BLOCK_SIZE_K , 0 ))
145+ off_k += BLOCK_SIZE_K
151146 c = accumulator .to (tl .float32 )
152147
153148 offset_c = bid .to (tl .int64 ) * stride_cz
154- c_block_ptr = tl .make_block_ptr (base = c_ptr + offset_c , shape = (M , N ), strides = (stride_cm , stride_cn ),
155- offsets = (pid_m * BLOCK_SIZE_M , pid_n * BLOCK_SIZE_N ),
156- block_shape = (BLOCK_SIZE_M , BLOCK_SIZE_N ), order = (1 , 0 ))
157- tl .store (c_block_ptr , c , boundary_check = (0 , 1 ))
149+ c_desc = tl .make_tensor_descriptor (base = c_ptr + offset_c , shape = (M , N ), strides = (stride_cm , stride_cn ),
150+ block_shape = (BLOCK_SIZE_M , BLOCK_SIZE_N ))
151+ c_desc .store ([pid_m * BLOCK_SIZE_M , pid_n * BLOCK_SIZE_N ], c )
158152
159153
160154# We can now create a convenience wrapper function that only takes two input tensors,
@@ -173,7 +167,7 @@ def matmul(a, b, c):
173167 triton .cdiv (M , META ['BLOCK_SIZE_M' ]) * triton .cdiv (N , META ['BLOCK_SIZE_N' ]),
174168 B ,
175169 )
176- matmul_kernel_with_block_pointers_batched [grid ](
170+ matmul_kernel_with_tensor_descriptors_batched [grid ](
177171 a , b , c , #
178172 B , M , N , K , #
179173 a .stride (0 ), a .stride (1 ), a .stride (2 ), #
@@ -186,7 +180,7 @@ def matmul(a, b, c):
186180 M , K = a .shape
187181 K , N = b .shape
188182 grid = lambda META : (triton .cdiv (M , META ['BLOCK_SIZE_M' ]) * triton .cdiv (N , META ['BLOCK_SIZE_N' ]), )
189- matmul_kernel_with_block_pointers [grid ](
183+ matmul_kernel_with_tensor_descriptors [grid ](
190184 a , b , c , #
191185 M , N , K , #
192186 a .stride (0 ), a .stride (1 ), #