@@ -55,7 +55,7 @@ def suffix():
5555 key = ['M' , 'N' , 'K' ],
5656)
5757@triton .jit
58- def matmul_kernel_with_block_pointers (
58+ def matmul_kernel_with_tensor_descriptors (
5959 # Pointers to matrices
6060 a_ptr , b_ptr , c_ptr , d_ptr ,
6161 # Matrix dimensions
@@ -78,31 +78,27 @@ def matmul_kernel_with_block_pointers(
7878 pid_m = first_pid_m + ((pid % num_pid_in_group ) % group_size_m )
7979 pid_n = (pid % num_pid_in_group ) // group_size_m
8080
81- a_block_ptr = tl .make_block_ptr (base = a_ptr , shape = (M , K ), strides = (stride_am , stride_ak ),
82- offsets = (pid_m * BLOCK_SIZE_M , 0 ), block_shape = (BLOCK_SIZE_M , BLOCK_SIZE_K ),
83- order = (1 , 0 ))
84- b_block_ptr = tl .make_block_ptr (base = b_ptr , shape = (K , N ), strides = (stride_bk , stride_bn ),
85- offsets = (0 , pid_n * BLOCK_SIZE_N ), block_shape = (BLOCK_SIZE_K , BLOCK_SIZE_N ),
86- order = (1 , 0 ))
81+ a_desc = tl .make_tensor_descriptor (base = a_ptr , shape = (M , K ), strides = (stride_am , stride_ak ),
82+ block_shape = (BLOCK_SIZE_M , BLOCK_SIZE_K ))
83+ b_desc = tl .make_tensor_descriptor (base = b_ptr , shape = (K , N ), strides = (stride_bk , stride_bn ),
84+ block_shape = (BLOCK_SIZE_K , BLOCK_SIZE_N ))
8785
8886 accumulator = tl .zeros ((BLOCK_SIZE_M , BLOCK_SIZE_N ), dtype = ACCUMULATOR_DTYPE )
87+ off_k = 0
8988 for _ in range (0 , K , BLOCK_SIZE_K ):
90- a = tl .load (a_block_ptr , boundary_check = ( 0 , 1 ) )
91- b = tl .load (b_block_ptr , boundary_check = ( 0 , 1 ) )
89+ a = a_desc .load ([ pid_m * BLOCK_SIZE_M , off_k ] )
90+ b = b_desc .load ([ off_k , pid_n * BLOCK_SIZE_N ] )
9291 accumulator += tl .dot (a , b )
93- a_block_ptr = tl .advance (a_block_ptr , (0 , BLOCK_SIZE_K ))
94- b_block_ptr = tl .advance (b_block_ptr , (BLOCK_SIZE_K , 0 ))
92+ off_k += BLOCK_SIZE_K
9593
96- d_block_ptr = tl .make_block_ptr (base = d_ptr , shape = (M , N ), strides = (stride_dm , stride_dn ),
97- offsets = (pid_m * BLOCK_SIZE_M , pid_n * BLOCK_SIZE_N ),
98- block_shape = (BLOCK_SIZE_M , BLOCK_SIZE_N ), order = (1 , 0 ))
99- d = tl .load (d_block_ptr , boundary_check = (0 , 1 ))
94+ d_desc = tl .make_tensor_descriptor (base = d_ptr , shape = (M , N ), strides = (stride_dm , stride_dn ),
95+ block_shape = (BLOCK_SIZE_M , BLOCK_SIZE_N ))
96+ d = d_desc .load ([pid_m * BLOCK_SIZE_M , pid_n * BLOCK_SIZE_N ])
10097 c = accumulator + d
10198
102- c_block_ptr = tl .make_block_ptr (base = c_ptr , shape = (M , N ), strides = (stride_cm , stride_cn ),
103- offsets = (pid_m * BLOCK_SIZE_M , pid_n * BLOCK_SIZE_N ),
104- block_shape = (BLOCK_SIZE_M , BLOCK_SIZE_N ), order = (1 , 0 ))
105- tl .store (c_block_ptr , c , boundary_check = (0 , 1 ))
99+ c_desc = tl .make_tensor_descriptor (base = c_ptr , shape = (M , N ), strides = (stride_cm , stride_cn ),
100+ block_shape = (BLOCK_SIZE_M , BLOCK_SIZE_N ))
101+ c_desc .store ([pid_m * BLOCK_SIZE_M , pid_n * BLOCK_SIZE_N ], c )
106102
107103
108104# pylint: disable=unused-argument
@@ -130,7 +126,7 @@ def matmul_kernel_with_block_pointers(
130126 key = ['M' , 'N' , 'K' ],
131127)
132128@triton .jit
133- def matmul_kernel_with_block_pointers_batched (
129+ def matmul_kernel_with_tensor_descriptors_batched (
134130 # Pointers to matrices
135131 a_ptr , b_ptr , c_ptr , d_ptr ,
136132 # Matrix dimensions
@@ -157,33 +153,30 @@ def matmul_kernel_with_block_pointers_batched(
157153 offset_a = bid .to (tl .int64 ) * stride_az
158154 offset_b = bid .to (tl .int64 ) * stride_bz
159155
160- a_block_ptr = tl .make_block_ptr (base = a_ptr + offset_a , shape = (M , K ), strides = (stride_am , stride_ak ),
161- offsets = (pid_m * BLOCK_SIZE_M , 0 ), block_shape = (BLOCK_SIZE_M , BLOCK_SIZE_K ),
162- order = (1 , 0 ))
163- b_block_ptr = tl .make_block_ptr (base = b_ptr + offset_b , shape = (K , N ), strides = (stride_bk , stride_bn ),
164- offsets = (0 , pid_n * BLOCK_SIZE_N ), block_shape = (BLOCK_SIZE_K , BLOCK_SIZE_N ),
165- order = (1 , 0 ))
156+ a_desc = tl .make_tensor_descriptor (base = a_ptr + offset_a , shape = (M , K ), strides = (stride_am , stride_ak ),
157+ block_shape = (BLOCK_SIZE_M , BLOCK_SIZE_K ))
158+ b_desc = tl .make_tensor_descriptor (base = b_ptr + offset_b , shape = (K , N ), strides = (stride_bk , stride_bn ),
159+ block_shape = (BLOCK_SIZE_K , BLOCK_SIZE_N ))
166160
167161 accumulator = tl .zeros ((BLOCK_SIZE_M , BLOCK_SIZE_N ), dtype = ACCUMULATOR_DTYPE )
162+ off_k = 0
168163 for _ in range (0 , K , BLOCK_SIZE_K ):
169- a = tl .load (a_block_ptr , boundary_check = ( 0 , 1 ) )
170- b = tl .load (b_block_ptr , boundary_check = ( 0 , 1 ) )
164+ a = a_desc .load ([ pid_m * BLOCK_SIZE_M , off_k ] )
165+ b = b_desc .load ([ off_k , pid_n * BLOCK_SIZE_N ] )
171166 accumulator += tl .dot (a , b )
172- a_block_ptr = tl .advance (a_block_ptr , (0 , BLOCK_SIZE_K ))
173- b_block_ptr = tl .advance (b_block_ptr , (BLOCK_SIZE_K , 0 ))
167+ off_k += BLOCK_SIZE_K
174168
175169 offset_d = bid .to (tl .int64 ) * stride_dz
176- d_block_ptr = tl .make_block_ptr (base = d_ptr + offset_d , shape = (M , N ), strides = (stride_dm , stride_dn ),
177- offsets = (pid_m * BLOCK_SIZE_M , pid_n * BLOCK_SIZE_N ),
178- block_shape = (BLOCK_SIZE_M , BLOCK_SIZE_N ), order = (1 , 0 ))
179- d = tl .load (d_block_ptr , boundary_check = (0 , 1 ))
170+ d_desc = tl .make_tensor_descriptor (base = d_ptr + offset_d , shape = (M , N ), strides = (stride_dm , stride_dn ),
171+ block_shape = (BLOCK_SIZE_M , BLOCK_SIZE_N ))
172+ d = d_desc .load ([pid_m * BLOCK_SIZE_M , pid_n * BLOCK_SIZE_N ])
180173 c = accumulator + d
181174
182175 offset_c = bid .to (tl .int64 ) * stride_cz
183- c_block_ptr = tl .make_block_ptr (base = c_ptr + offset_c , shape = (M , N ), strides = (stride_cm , stride_cn ),
184- offsets = ( pid_m * BLOCK_SIZE_M , pid_n * BLOCK_SIZE_N ),
185- block_shape = ( BLOCK_SIZE_M , BLOCK_SIZE_N ), order = ( 1 , 0 ))
186- tl .store (c_block_ptr , c , boundary_check = ( 0 , 1 ) )
176+ c_desc = tl .make_tensor_descriptor (base = c_ptr + offset_c , shape = (M , N ), strides = (stride_cm , stride_cn ),
177+ block_shape = ( BLOCK_SIZE_M , BLOCK_SIZE_N ))
178+
179+ c_desc .store ([ pid_m * BLOCK_SIZE_M , pid_n * BLOCK_SIZE_N ], c )
187180
188181
189182# We can now create a convenience wrapper function that only takes two input tensors,
@@ -202,7 +195,7 @@ def matmul(a, b, d, c):
202195 triton .cdiv (M , META ['BLOCK_SIZE_M' ]) * triton .cdiv (N , META ['BLOCK_SIZE_N' ]),
203196 B ,
204197 )
205- matmul_kernel_with_block_pointers_batched [grid ](
198+ matmul_kernel_with_tensor_descriptors_batched [grid ](
206199 a , b , c , d , #
207200 B , M , N , K , #
208201 a .stride (0 ), a .stride (1 ), a .stride (2 ), #
@@ -217,7 +210,7 @@ def matmul(a, b, d, c):
217210 M , K = a .shape
218211 K , N = b .shape
219212 grid = lambda META : (triton .cdiv (M , META ['BLOCK_SIZE_M' ]) * triton .cdiv (N , META ['BLOCK_SIZE_N' ]), )
220- matmul_kernel_with_block_pointers [grid ](
213+ matmul_kernel_with_tensor_descriptors [grid ](
221214 a , b , c , d , #
222215 M , N , K , #
223216 a .stride (0 ), a .stride (1 ), #
0 commit comments