@@ -101,6 +101,7 @@ def matmul_kernel(
     stride_cm,
     stride_cn,
     chunk_trun_bits,
+    max_acc_bits,  # pylint: disable=unused-argument
     truncate_then_accumulate,
     # Meta-parameters
     BLOCK_SIZE_M: tl.constexpr,
@@ -212,6 +213,7 @@ def imatmul_kernel(
     stride_cm,
     stride_cn,
     chunk_trun_bits,
+    max_acc_bits,
     truncate_then_accumulate,
     # Meta-parameters
     BLOCK_SIZE_M: tl.constexpr,
@@ -220,8 +222,8 @@ def imatmul_kernel(
     GROUP_SIZE_M: tl.constexpr,
     ACTIVATION: tl.constexpr,
 ):
-    """Kernel for computing the INT matmul C = A x B that include LSB truncation. A and B should be
-    INT8, C should be INT32. (Pretty much the same code as float version.)
+    """Kernel for computing the INT matmul D = A x B + C that includes LSB truncation and MSB
+    clamping. A and B should be INT8, C/D should be INT32. (Similar to the float version.)
     A has shape (M, K), B has shape (K, N) and C has shape (M, N)
     Args:
         chunk_trun_bits (int): number of LSBs to truncate/round.
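For orientation, here is a minimal plain-Python sketch of the two accumulation orders named in this docstring and in truncate_then_accumulate below. All names and values are illustrative stand-ins (scalars instead of Triton blocks); trunc stands for the clamp + round/shift steps added further down in this diff.

    def trunc(x, bits=4):
        # stand-in for the kernel's per-chunk clamp + round/shift
        round_bit = 1 << (bits - 1) if bits > 0 else 0
        return ((x + round_bit) >> bits) << bits

    acc = 7                      # accumulator now starts from C (loaded below)
    for p in [100, 105, -37]:    # stand-ins for per-chunk a_k @ b_k results
        acc = trunc(p) + acc     # truncate_then_accumulate=True
        # acc = trunc(acc + p)   # truncate_then_accumulate=False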
@@ -238,14 +240,20 @@ def imatmul_kernel(

     offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     offs_k = tl.arange(0, BLOCK_SIZE_K)
     a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
     b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)

-    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)
-    ## ------ prepare LSB rounding/truncation masks -------
+    # accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)
+    accumulator = tl.load(c_ptrs, mask=c_mask, other=0)
+    ## ------ prepare MSB/LSB rounding/truncation masks -------
     round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0
-    # msb_mask = 0x00FFFFFF  # only needed when simulating truncation on MSB
+    acc_min = -(1 << (max_acc_bits - 1))
+    acc_max = -acc_min - 1
     ## ---------------------------------------------------------

     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
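The bounds just computed saturate each chunk to a narrower integer range. A quick plain-Python check for an emulated INT24 accumulator (max_acc_bits = 24):

    max_acc_bits = 24
    acc_min = -(1 << (max_acc_bits - 1))   # -8388608 == -2**23
    acc_max = -acc_min - 1                 #  8388607 ==  2**23 - 1
    assert (acc_min, acc_max) == (-(2**23), 2**23 - 1)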
@@ -256,7 +264,12 @@ def imatmul_kernel(
         else:
             accumulator_inner = tl.dot(a, b, accumulator, input_precision="ieee")

-        ## ------ add chunky LSB rounding/masking --------
+        ## ------ INT MSB truncation is simulated by clamping; INT LSB
+        ##        truncation by a right shift followed by a left shift --------
+        if max_acc_bits < 32:
+            accumulator_inner = tl.maximum(
+                tl.minimum(accumulator_inner, acc_max), acc_min
+            )
         if chunk_trun_bits != 0:
             accumulator_inner = (accumulator_inner + round_bit) >> chunk_trun_bits
             accumulator_inner = accumulator_inner << chunk_trun_bits
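The shift pair above rounds to the nearest multiple of 2**chunk_trun_bits (ties toward +inf) and clears the low bits; a plain-Python check of the same arithmetic:

    chunk_trun_bits = 4
    round_bit = 1 << (chunk_trun_bits - 1)   # 8
    def rt(x):
        return ((x + round_bit) >> chunk_trun_bits) << chunk_trun_bits
    assert rt(100) == 96     # low bits 4 -> rounds down
    assert rt(105) == 112    # low bits 9 -> rounds up
    assert rt(-37) == -32    # arithmetic shift on signed ints, as in Triton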
@@ -275,8 +288,6 @@ def imatmul_kernel(

     offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
-    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
     tl.store(c_ptrs, c, mask=c_mask)


@@ -300,6 +311,7 @@ def matmul_kernel_DABC(
     stride_cm,
     stride_cn,
     chunk_trun_bits,
+    max_acc_bits,  # pylint: disable=unused-argument
     truncate_then_accumulate,
     # Meta-parameters
     BLOCK_SIZE_M: tl.constexpr,
@@ -421,6 +433,7 @@ def tl_matmul_chunk_truncate(
     activation="",
     chunk_trun_bits=0,
     chunk_size=16,
+    max_acc_bits=32,
     truncate_then_accumulate=True,
     cast_output_to_input_dtype=None,
 ):
@@ -434,6 +447,9 @@ def tl_matmul_chunk_truncate(
         activation (str, optional): activation func to be fused, see relu example.
         chunk_trun_bits (int, optional): number of LSBs to be truncated/rounded.
         chunk_size (int, optional): BLOCK_SIZE_K, some HW has a specific chunk size. Must be >= 16.
+        max_acc_bits (int, optional): number of bits for the accumulator, e.g. if INT24 is used,
+            each chunk of a*b is clamped to [-2**23, 2**23 - 1].
+            (i.e., saturating arithmetic: values clamp rather than wrap on overflow)
         truncate_then_accumulate (bool, optional): if True, c = truncate(a*b) + c, otherwise
             c = truncate(a*b+c)
         cast_output_to_input_dtype (bool, optional): accumulator has higher prec than input, usually
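A hypothetical call exercising the new parameter. The int8 inputs and the leading positional arguments (a, b) are assumptions for illustration; this hunk does not show the function's positional signature:

    import torch
    a = torch.randint(-128, 128, (64, 128), dtype=torch.int8, device="cuda")
    b = torch.randint(-128, 128, (128, 32), dtype=torch.int8, device="cuda")
    out = tl_matmul_chunk_truncate(
        a, b,
        chunk_trun_bits=4,              # round/truncate 4 LSBs per chunk
        chunk_size=16,                  # BLOCK_SIZE_K
        max_acc_bits=24,                # clamp each chunk to the INT24 range
        truncate_then_accumulate=True,  # truncate each chunk, then accumulate
    )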
@@ -472,9 +488,9 @@ def isPowerofTwo(x):

     # because min k (chunk size in this case) for fp16/bf16 is 16, if smaller is needed, we could
     # insert 0s in between elements, e.g. pad [m,k] -> [m,2k], [k,n] -> [2k,n], out=[m,n] unchanged.
-    # Do not support INT8 for now.
     if chunk_size == 8 and a.dtype in [
         torch.float8_e4m3fn,
+        torch.int8,
         torch.float16,
         torch.bfloat16,
     ]:
@@ -515,7 +531,6 @@ def isPowerofTwo(x):
     c_org_dtype = c.dtype
     c = c.to(acc_dtype)
     assert c.shape[0] == M and c.shape[1] == N, "C shape is inconsistent with A B."
-    assert acc_dtype == torch.float32, "INT truncation is not yet supported."

     # 1D launch kernel where each block gets its own program.
     def grid(META):
@@ -556,6 +571,7 @@ def grid(META):
         c.stride(0),
         c.stride(1),
         chunk_trun_bits=chunk_trun_bits,
+        max_acc_bits=max_acc_bits,
         truncate_then_accumulate=truncate_then_accumulate,
         ACTIVATION=activation,
         **kernel_config,  # if using auto-tune, comment this line out.