@@ -101,6 +101,7 @@ def matmul_kernel(
101101 stride_cm ,
102102 stride_cn ,
103103 chunk_trun_bits ,
104+ truncate_then_accumulate ,
104105 # Meta-parameters
105106 BLOCK_SIZE_M : tl .constexpr ,
106107 BLOCK_SIZE_N : tl .constexpr ,
@@ -159,15 +160,20 @@ def matmul_kernel(
159160 a = tl .load (a_ptrs , mask = offs_k [None , :] < K - k * BLOCK_SIZE_K , other = 0.0 )
160161 b = tl .load (b_ptrs , mask = offs_k [:, None ] < K - k * BLOCK_SIZE_K , other = 0.0 )
161162 # We accumulate along the K dimension.
162- accumulator = tl .dot (a , b , accumulator , input_precision = "ieee" )
163+ if truncate_then_accumulate :
164+ accumulator_inner = tl .dot (a , b , input_precision = "ieee" )
165+ else :
166+ accumulator_inner = tl .dot (a , b , accumulator , input_precision = "ieee" )
163167 # tl.dot() default is using TF32 approximation, not good enough for LSB truncation exp
164168
165169 ## ------ add chunky LSB rounding/masking --------
166170 if chunk_trun_bits > 0 :
167- accumulator = libdevice .uint_as_float (
168- (libdevice .float_as_uint (accumulator ) + round_bit ) & trun_mask
169- )
171+ accumulator_inner = round_and_trun (accumulator_inner , round_bit , trun_mask )
170172 ## ---------------------------------------------------------
173+ if truncate_then_accumulate :
174+ accumulator += accumulator_inner
175+ else :
176+ accumulator = accumulator_inner
171177
172178 # Advance the ptrs to the next K block.
173179 a_ptrs += BLOCK_SIZE_K * stride_ak
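For reference, a minimal host-side sketch (not part of this commit) of what the new `truncate_then_accumulate` flag changes. It mirrors the kernel's chunked loop and its fp32 round-and-truncate, assuming NumPy for the bit reinterpretation; all names below are illustrative only.

import numpy as np

def round_and_trun_fp32(x, chunk_trun_bits):
    """Round to nearest, then zero the lowest `chunk_trun_bits` bits of each fp32 value."""
    if chunk_trun_bits == 0:
        return x
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    trun_mask = np.uint32((0xFFFFFFFF >> chunk_trun_bits) << chunk_trun_bits)
    round_bit = np.uint32(1 << (chunk_trun_bits - 1))
    return ((bits + round_bit) & trun_mask).view(np.float32)

def matmul_chunk_truncate_ref(a, b, chunk_size=16, chunk_trun_bits=0,
                              truncate_then_accumulate=True):
    """Chunked fp32 matmul emulating the two accumulation orders of the kernel."""
    acc = np.zeros((a.shape[0], b.shape[1]), dtype=np.float32)
    for k0 in range(0, a.shape[1], chunk_size):
        partial = a[:, k0:k0 + chunk_size] @ b[k0:k0 + chunk_size, :]
        if truncate_then_accumulate:
            acc += round_and_trun_fp32(partial, chunk_trun_bits)       # acc = trunc(a*b) + acc
        else:
            acc = round_and_trun_fp32(acc + partial, chunk_trun_bits)  # acc = trunc(a*b + acc)
    return acc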
@@ -206,6 +212,7 @@ def imatmul_kernel(
206212 stride_cm ,
207213 stride_cn ,
208214 chunk_trun_bits ,
215+ truncate_then_accumulate ,
209216 # Meta-parameters
210217 BLOCK_SIZE_M : tl .constexpr ,
211218 BLOCK_SIZE_N : tl .constexpr ,
@@ -244,13 +251,20 @@ def imatmul_kernel(
244251 for k in range (0 , tl .cdiv (K , BLOCK_SIZE_K )):
245252 a = tl .load (a_ptrs , mask = offs_k [None , :] < K - k * BLOCK_SIZE_K , other = 0.0 )
246253 b = tl .load (b_ptrs , mask = offs_k [:, None ] < K - k * BLOCK_SIZE_K , other = 0.0 )
247- accumulator = tl .dot (a , b , accumulator , input_precision = "ieee" )
254+ if truncate_then_accumulate :
255+ accumulator_inner = tl .dot (a , b , input_precision = "ieee" )
256+ else :
257+ accumulator_inner = tl .dot (a , b , accumulator , input_precision = "ieee" )
248258
249259 ## ------ add chunky LSB rounding/masking --------
250260 if chunk_trun_bits != 0 :
251- accumulator = (accumulator + round_bit ) >> chunk_trun_bits
252- accumulator = accumulator << chunk_trun_bits
261+ accumulator_inner = (accumulator_inner + round_bit ) >> chunk_trun_bits
262+ accumulator_inner = accumulator_inner << chunk_trun_bits
253263 ## ---------------------------------------------------------
264+ if truncate_then_accumulate :
265+ accumulator += accumulator_inner
266+ else :
267+ accumulator = accumulator_inner
254268
255269 a_ptrs += BLOCK_SIZE_K * stride_ak
256270 b_ptrs += BLOCK_SIZE_K * stride_bk
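The integer kernel clears low bits with a round-then-shift instead of a bit mask. A tiny worked example in plain Python of what the two shifts above do to a partial sum (helper name is illustrative; assumes chunk_trun_bits > 0, as guarded in the kernel):

def round_and_trun_int(x, chunk_trun_bits):
    # Round half up to the nearest multiple of 2**chunk_trun_bits, then clear the low bits.
    round_bit = 1 << (chunk_trun_bits - 1)
    return ((x + round_bit) >> chunk_trun_bits) << chunk_trun_bits

assert round_and_trun_int(1000, 4) == 1008  # 1000 / 16 = 62.5, rounds half up to 63 * 16
assert round_and_trun_int(1007, 4) == 1008  # 1007 is closest to 63 * 16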
@@ -266,29 +280,162 @@ def imatmul_kernel(
266280 tl .store (c_ptrs , c , mask = c_mask )
267281
268282
283+ @triton .jit
284+ def matmul_kernel_DABC (
285+ # Pointers to matrices
286+ a_ptr ,
287+ b_ptr ,
288+ c_ptr ,
289+ # Matrix dimensions
290+ M ,
291+ N ,
292+ K ,
293+ # The stride variables represent how much to increase the ptr by when moving by 1
294+ # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
295+ # by to get the element one row down (A has M rows).
296+ stride_am ,
297+ stride_ak ,
298+ stride_bk ,
299+ stride_bn ,
300+ stride_cm ,
301+ stride_cn ,
302+ chunk_trun_bits ,
303+ truncate_then_accumulate ,
304+ # Meta-parameters
305+ BLOCK_SIZE_M : tl .constexpr ,
306+ BLOCK_SIZE_N : tl .constexpr ,
307+ BLOCK_SIZE_K : tl .constexpr ,
308+ GROUP_SIZE_M : tl .constexpr ,
309+ ACTIVATION : tl .constexpr ,
310+ ):
311+ """Kernel for computing the matmul D = A x B + C that include LSB truncation.
312+ A has shape (M, K), B has shape (K, N) and C/D has shape (M, N).
313+ NOTE:
314+ C should be consistent with accumulator dtype, e.g. fp8xfp8 -> fp32.
315+ *The D ptr is assumed to be the same as the C ptr, so D is not passed as an arg.
316+ **C can also be used to verify unintended truncation by CUDA.
317+ Args:
318+ chunk_trun_bits (int): number of LSBs to truncate/round, in [0, 23].
319+ """
320+ # -----------------------------------------------------------
321+ # Map program ids `pid` to the block of C it should compute.
322+ # This is done in a grouped ordering to promote L2 data reuse.
323+ # See above `L2 Cache Optimizations` section for details.
324+ pid = tl .program_id (axis = 0 )
325+ num_pid_m = tl .cdiv (M , BLOCK_SIZE_M )
326+ num_pid_n = tl .cdiv (N , BLOCK_SIZE_N )
327+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
328+ group_id = pid // num_pid_in_group
329+ first_pid_m = group_id * GROUP_SIZE_M
330+ group_size_m = min (num_pid_m - first_pid_m , GROUP_SIZE_M )
331+ pid_m = first_pid_m + ((pid % num_pid_in_group ) % group_size_m )
332+ pid_n = (pid % num_pid_in_group ) // group_size_m
333+
334+ # ----------------------------------------------------------
335+ # Create pointers for the first blocks of A and B.
336+ # We will advance this pointer as we move in the K direction
337+ # and accumulate
338+ # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
339+ # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
340+ # See above `Pointer Arithmetic` section for details
341+ offs_am = (pid_m * BLOCK_SIZE_M + tl .arange (0 , BLOCK_SIZE_M )) % M
342+ offs_bn = (pid_n * BLOCK_SIZE_N + tl .arange (0 , BLOCK_SIZE_N )) % N
343+ offs_cm = pid_m * BLOCK_SIZE_M + tl .arange (0 , BLOCK_SIZE_M )
344+ offs_cn = pid_n * BLOCK_SIZE_N + tl .arange (0 , BLOCK_SIZE_N )
345+ offs_k = tl .arange (0 , BLOCK_SIZE_K )
346+ a_ptrs = a_ptr + (offs_am [:, None ] * stride_am + offs_k [None , :] * stride_ak )
347+ b_ptrs = b_ptr + (offs_k [:, None ] * stride_bk + offs_bn [None , :] * stride_bn )
348+ c_ptrs = c_ptr + stride_cm * offs_cm [:, None ] + stride_cn * offs_cn [None , :]
349+ c_mask = (offs_cm [:, None ] < M ) & (offs_cn [None , :] < N )
350+
351+ # -----------------------------------------------------------
352+ # Iterate to compute a block of the C matrix.
353+ # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
354+ # of fp32 values for higher accuracy, i.e. C should have been cast to fp32 already
355+ accumulator = tl .load (c_ptrs , mask = c_mask , other = 0.0 )
356+ ## ------ prepare LSB rounding/truncation masks -------
357+ # NOTE the mask is applied to the accumulator, which is always FP32, so we may truncate up to 23 bits
358+ # e.g., 20b -> trun_mask = 0xFFF00000, round_bit = 0x00080000
359+ # 8b -> trun_mask = 0xFFFFFF00, round_bit = 0x00000080
360+ trun_mask = tl .cast ((0xFFFFFFFF >> chunk_trun_bits ) << chunk_trun_bits , tl .uint32 )
361+ round_bit = 1 << (chunk_trun_bits - 1 ) if chunk_trun_bits > 0 else 0
362+ ## ---------------------------------------------------------
363+
364+ for k in range (0 , tl .cdiv (K , BLOCK_SIZE_K )):
365+ # Load the next block of A, B, and C, generate a mask by checking the K dimension.
366+ # If it is out of bounds, set it to 0.
367+ # D = truncation(A*B) + C
368+ a = tl .load (a_ptrs , mask = offs_k [None , :] < K - k * BLOCK_SIZE_K , other = 0.0 )
369+ b = tl .load (b_ptrs , mask = offs_k [:, None ] < K - k * BLOCK_SIZE_K , other = 0.0 )
370+ # We accumulate along the K dimension, but apply truncation to the local A*B first.
371+ if truncate_then_accumulate :
372+ accumulator_inner = tl .dot (a , b , input_precision = "ieee" )
373+ else :
374+ accumulator_inner = tl .dot (a , b , accumulator , input_precision = "ieee" )
375+ # tl.dot() defaults to TF32 approximation, which is not accurate enough for LSB truncation experiments.
376+ # NOTE: tl.dot(a, b, c) should map to a CUDA mma instruction, typically "c = a*b + c".
377+ # If that mma instruction uses reduced precision under the hood, not only is a*b
378+ # accumulated in that precision, but c may also be cast to the lower precision,
379+ # and hence lose some precision.
380+
381+ ## ------ add chunky LSB rounding/masking --------
382+ if chunk_trun_bits > 0 :
383+ accumulator_inner = round_and_trun (accumulator_inner , round_bit , trun_mask )
384+ ## ---------------------------------------------------------
385+ if truncate_then_accumulate :
386+ accumulator += accumulator_inner
387+ else :
388+ accumulator = accumulator_inner
389+
390+ # Advance the ptrs to the next K block.
391+ a_ptrs += BLOCK_SIZE_K * stride_ak
392+ b_ptrs += BLOCK_SIZE_K * stride_bk
393+ # You can fuse arbitrary activation functions here
394+ # while the accumulator is still in FP32!
395+ if ACTIVATION == "leaky_relu" :
396+ accumulator = leaky_relu (accumulator )
397+
398+ d = accumulator # do not cast to (tl.float16) just yet
399+
400+ # -----------------------------------------------------------
401+ # Write back the block of the output to matrix "C" with masks.
402+ tl .store (c_ptrs , d , mask = c_mask )
403+
404+
269405@triton .jit
270406def leaky_relu (x ):
271407 """Activation function that could be fused into matmul kernel"""
272408 return tl .where (x >= 0 , x , 0.01 * x )
273409
274410
411+ @triton .jit
412+ def round_and_trun (x , round_bit , trun_mask ):
413+ """Round and truncate (usually for accumulator)."""
414+ return libdevice .uint_as_float ((libdevice .float_as_uint (x ) + round_bit ) & trun_mask )
415+
416+
275417def tl_matmul_chunk_truncate (
276418 a ,
277419 b ,
420+ c = None ,
278421 activation = "" ,
279422 chunk_trun_bits = 0 ,
280423 chunk_size = 16 ,
424+ truncate_then_accumulate = True ,
281425 cast_output_to_input_dtype = None ,
282426):
283427 """Triton matmul for HW behavior simulation. Supports float and int8.
284- a. variable chunk size (i.e., BLOCK_SIZE_K)
285- b. LSB truncation, must <23 if using float.
428+ i. variable chunk size (i.e., BLOCK_SIZE_K)
429+ ii. LSB truncation, must be < 23 bits if using float.
430+ iii. assumes D = A*B + C, where C is optional. If C is given, it is updated in place.
286431
287432 Args:
288433 a, b: input tensors. FloatX, X in [32, 16, 8] or INT8.
289434 activation (str, optional): activation func to be fused, see relu example.
290435 chunk_trun_bits (int, optional): number of LSBs to be truncated/rounded.
291436 chunk_size (int, optional): BLOCK_SIZE_K, some HW has specific chunk size. must >= 16.
437+ truncate_then_accumulate (bool, optional): if True, c = truncate(a*b) + c, otherwise
438+ c = truncate(a*b+c)
292439 cast_output_to_input_dtype (bool, optional): accumulator has higher prec than input, usually
293440 FP32 or INT32. by default we cast the final
294441 output to the same dtype as input for non-8bits.
@@ -300,6 +447,7 @@ def tl_matmul_chunk_truncate(
300447 use empirical way to determine BLOCK sizes, may not be optimal. But need to avoid autotune for
301448 real model inference. otherwise auto-tune will be triggered in every forward call.
302449 """
450+
303451 # Check constraints.
304452 assert a .shape [1 ] == b .shape [0 ], "Incompatible dimensions"
305453 assert a .is_contiguous (), "Matrix A must be contiguous"
@@ -314,28 +462,60 @@ def tl_matmul_chunk_truncate(
314462 if cuda_cc [0 ] >= 9 or cuda_cc == (8 , 9 ):
315463 allowed_dtypes += DTYPE_F8
316464 assert a .dtype in allowed_dtypes , "Input dtype is not supported"
317- M , K = a .shape
318- K , N = b .shape
319465
320466 # Allocates output, always accumulate in FP32 (if floats) or INT32 then cast
321467 def isPowerofTwo (x ):
322468 """triton-specific limitation: block size needs to be power of 2."""
323469 return (x & (x - 1 )) == 0
324470
325471 min_chunk_size = 32 if a .dtype in DTYPE_8BIT else 16
326- if isPowerofTwo (chunk_size ):
327- chunk_size = max (chunk_size , min_chunk_size )
328- else :
472+
473+ # Because the minimum K (i.e., chunk size here) for fp16/bf16 is 16, a smaller chunk can be
474+ # emulated by interleaving 0s, e.g. pad [m,k] -> [m,2k] and [k,n] -> [2k,n]; out=[m,n] is unchanged.
475+ # INT8 is not supported for now.
476+ if chunk_size == 8 and a .dtype in [
477+ torch .float8_e4m3fn ,
478+ torch .float16 ,
479+ torch .bfloat16 ,
480+ ]:
481+ exp_ratio = min_chunk_size // chunk_size
482+ a_padded = torch .zeros (
483+ a .shape [0 ], a .shape [1 ] * exp_ratio , dtype = a .dtype , device = a .device
484+ )
485+ a_padded [:, ::exp_ratio ] = a
486+ a = a_padded
487+ b_padded = torch .zeros (
488+ b .shape [0 ] * exp_ratio , b .shape [1 ], dtype = b .dtype , device = b .device
489+ )
490+ b_padded [::exp_ratio , :] = b
491+ b = b_padded
329492 chunk_size = min_chunk_size
493+ else :
494+ chunk_size = (
495+ max (chunk_size , min_chunk_size )
496+ if isPowerofTwo (chunk_size )
497+ else min_chunk_size
498+ )
330499
500+ M , K = a .shape
501+ K , N = b .shape
331502 if a .dtype in DTYPE_I8 :
332503 acc_dtype = torch .int32
333504 mm_kernel = imatmul_kernel
334505 else :
335506 acc_dtype = torch .float32
336- mm_kernel = matmul_kernel
507+ mm_kernel = matmul_kernel if c is None else matmul_kernel_DABC
337508 assert chunk_trun_bits < 23 , "FP32 accumulator only has 23 mantissa bits"
338- c = torch .zeros ((M , N ), device = a .device , dtype = acc_dtype )
509+
510+ if c is None :
511+ c_org_dtype = a .dtype
512+ c = torch .zeros ((M , N ), device = a .device , dtype = acc_dtype )
513+ else :
514+ # if C is in fp16, accumulate in fp32 no matter what, decide whether to cast back later
515+ c_org_dtype = c .dtype
516+ c = c .to (acc_dtype )
517+ assert c .shape [0 ] == M and c .shape [1 ] == N , "C shape is inconsistent with A B."
518+ assert acc_dtype == torch .float32 , "INT truncation is not yet supported."
339519
340520 # 1D launch kernel where each block gets its own program.
341521 def grid (META ):
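As a quick standalone check (illustrative, not part of the commit), the zero-interleaving trick used above for chunk_size == 8 leaves the product unchanged, since every inserted term is a multiply by zero:

import torch

a = torch.randn(4, 8, dtype=torch.bfloat16)
b = torch.randn(8, 4, dtype=torch.bfloat16)
# Interleave zeros along K: [m, k] -> [m, 2k] and [k, n] -> [2k, n]
a_pad = torch.zeros(a.shape[0], a.shape[1] * 2, dtype=a.dtype)
a_pad[:, ::2] = a
b_pad = torch.zeros(b.shape[0] * 2, b.shape[1], dtype=b.dtype)
b_pad[::2, :] = b
assert torch.allclose(a.float() @ b.float(), a_pad.float() @ b_pad.float())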
@@ -345,7 +525,7 @@ def grid(META):
345525
346526 if M < 1024 or N < 1024 :
347527 kernel_config = {
348- "BLOCK_SIZE_M" : 128 ,
528+ "BLOCK_SIZE_M" : 64 ,
349529 "BLOCK_SIZE_K" : chunk_size ,
350530 "BLOCK_SIZE_N" : 32 ,
351531 "GROUP_SIZE_M" : 8 ,
@@ -376,7 +556,8 @@ def grid(META):
376556 c .stride (0 ),
377557 c .stride (1 ),
378558 chunk_trun_bits = chunk_trun_bits ,
559+ truncate_then_accumulate = truncate_then_accumulate ,
379560 ACTIVATION = activation ,
380561 ** kernel_config , # if using auto-tune, comment this line out.
381562 )
382- return c .to (a . dtype ) if cast_output_to_input_dtype else c
563+ return c .to (c_org_dtype ) if cast_output_to_input_dtype else c
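A possible call site for the extended API (illustrative shapes and parameters; assumes a CUDA device and the module-level imports already present in this file):

a = torch.randn(256, 512, device="cuda", dtype=torch.float16)
b = torch.randn(512, 128, device="cuda", dtype=torch.float16)
c = torch.zeros(256, 128, device="cuda", dtype=torch.float32)
# D = truncate(A @ B) + C per 16-wide K chunk, clearing the 12 LSBs of each fp32 partial product.
# Since c is already fp32, it is updated in place and also returned (still fp32, because
# cast_output_to_input_dtype is left unset).
d = tl_matmul_chunk_truncate(
    a, b, c=c, chunk_trun_bits=12, chunk_size=16, truncate_then_accumulate=True
)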