@@ -101,6 +101,7 @@ def matmul_kernel(
     stride_cm,
     stride_cn,
     chunk_trun_bits,
+    truncate_then_accumulate,
     # Meta-parameters
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
@@ -159,15 +160,20 @@ def matmul_kernel(
         a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
         b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
         # We accumulate along the K dimension.
-        accumulator = tl.dot(a, b, accumulator, input_precision="ieee")
+        if truncate_then_accumulate:
+            accumulator_inner = tl.dot(a, b, input_precision="ieee")
+        else:
+            accumulator_inner = tl.dot(a, b, accumulator, input_precision="ieee")
         # tl.dot() default is using TF32 approximation, not good enough for LSB truncation exp
 
         ## ------ add chunky LSB rounding/masking --------
         if chunk_trun_bits > 0:
-            accumulator = libdevice.uint_as_float(
-                (libdevice.float_as_uint(accumulator) + round_bit) & trun_mask
-            )
+            accumulator_inner = round_and_trun(accumulator_inner, round_bit, trun_mask)
         ## ---------------------------------------------------------
+        if truncate_then_accumulate:
+            accumulator += accumulator_inner
+        else:
+            accumulator = accumulator_inner
 
         # Advance the ptrs to the next K block.
         a_ptrs += BLOCK_SIZE_K * stride_ak
@@ -206,6 +212,7 @@ def imatmul_kernel(
     stride_cm,
     stride_cn,
     chunk_trun_bits,
+    truncate_then_accumulate,
     # Meta-parameters
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
@@ -244,13 +251,21 @@ def imatmul_kernel(
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
         b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
-        accumulator = tl.dot(a, b, accumulator, input_precision="ieee")
+        if truncate_then_accumulate:
+            accumulator_inner = tl.dot(a, b, input_precision="ieee")
+        else:
+            accumulator_inner = tl.dot(a, b, accumulator, input_precision="ieee")
 
         ## ------ add chunky LSB rounding/masking --------
         if chunk_trun_bits != 0:
-            accumulator = (accumulator + round_bit) >> chunk_trun_bits
-            accumulator = accumulator << chunk_trun_bits
+            accumulator_inner = (accumulator_inner + round_bit) >> chunk_trun_bits
+            accumulator_inner = accumulator_inner << chunk_trun_bits
         ## ---------------------------------------------------------
+        if truncate_then_accumulate:
+            accumulator += accumulator_inner
+        else:
+            accumulator = accumulator_inner
+
 
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * stride_bk
@@ -266,29 +281,163 @@ def imatmul_kernel(
     tl.store(c_ptrs, c, mask=c_mask)
 
 
+
+@triton.jit
+def matmul_kernel_DABC(
+    # Pointers to matrices
+    a_ptr,
+    b_ptr,
+    c_ptr,
+    # Matrix dimensions
+    M,
+    N,
+    K,
+    # The stride variables represent how much to increase the ptr by when moving by 1
+    # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
+    # by to get the element one row down (A has M rows).
+    stride_am,
+    stride_ak,
+    stride_bk,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    chunk_trun_bits,
+    truncate_then_accumulate,
+    # Meta-parameters
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+    ACTIVATION: tl.constexpr,
+):
313+ """Kernel for computing the matmul D = A x B + C that include LSB truncation.
314+ A has shape (M, K), B has shape (K, N) and C/D has shape (M, N).
315+ NOTE:
316+ C should be consistent with accumulator dtype, e.g. fp8xfp8 -> fp32.
317+ *D ptr is supposed to be the same as C ptr, no need to provide D as arg
318+ **we can be used C to verify unintended truncation by CUDA as well.
319+ Args:
320+ chunk_trun_bits (int): number of LSB to truncate/round. [0 to 23]
321+ """
+    # -----------------------------------------------------------
+    # Map program ids `pid` to the block of C it should compute.
+    # This is done in a grouped ordering to promote L2 data reuse.
+    # See above `L2 Cache Optimizations` section for details.
+    pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+    group_id = pid // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
+    pid_n = (pid % num_pid_in_group) // group_size_m
+
+    # ----------------------------------------------------------
+    # Create pointers for the first blocks of A and B.
+    # We will advance this pointer as we move in the K direction
+    # and accumulate
+    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
+    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
+    # See above `Pointer Arithmetic` section for details
+    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+
+    # -----------------------------------------------------------
+    # Iterate to compute a block of the C matrix.
+    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
+    # of fp32 values for higher accuracy.
+    # accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    accumulator = tl.load(c_ptrs, mask=c_mask, other=0.0)  # should have been cast to fp32 already
+    ## ------ prepare LSB rounding/truncation masks -------
+    # NOTE the mask will be applied to the accumulator, which is always FP32, so we may truncate up to 23 bits
+    # e.g., 20b -> trun_mask = 0xFFF00000, round_bit = 0x00080000
+    #        8b -> trun_mask = 0xFFFFFF00, round_bit = 0x00000080
+    trun_mask = tl.cast((0xFFFFFFFF >> chunk_trun_bits) << chunk_trun_bits, tl.uint32)
+    round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0
+    ## ---------------------------------------------------------
+
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        # Load the next block of A, B, and C, generate a mask by checking the K dimension.
+        # If it is out of bounds, set it to 0.
+        # D = truncation(A*B) + C
+        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+        # We accumulate along the K dimension, but apply truncation to the local A*B first.
+        if truncate_then_accumulate:
+            accumulator_inner = tl.dot(a, b, input_precision="ieee")
+        else:
+            accumulator_inner = tl.dot(a, b, accumulator, input_precision="ieee")
+        # tl.dot() default is using TF32 approximation, not good enough for LSB truncation exp
+        # NOTE: tl.dot(a, b, c) should use one single CUDA mma instruction to handle "c = a*b+c". If
+        #       this mma instruction uses "reduced precision" under the hood, not only will a*b
+        #       be accumulated in that precision, c will most likely be cast to that "lower"
+        #       precision first, hence losing some precision!
+
+        ## ------ add chunky LSB rounding/masking --------
+        if chunk_trun_bits > 0:
+            accumulator_inner = round_and_trun(accumulator_inner, round_bit, trun_mask)
+        ## ---------------------------------------------------------
+        if truncate_then_accumulate:
+            accumulator += accumulator_inner
+        else:
+            accumulator = accumulator_inner
+
+        # Advance the ptrs to the next K block.
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+    # You can fuse arbitrary activation functions here
+    # while the accumulator is still in FP32!
+    if ACTIVATION == "leaky_relu":
+        accumulator = leaky_relu(accumulator)
+
+    d = accumulator  # do not cast to (tl.float16) just yet
+
+    # -----------------------------------------------------------
+    # Write back the block of the output to matrix "C" with masks.
+    tl.store(c_ptrs, d, mask=c_mask)
+
+
 @triton.jit
 def leaky_relu(x):
     """Activation function that could be fused into matmul kernel"""
     return tl.where(x >= 0, x, 0.01 * x)
 
 
+@triton.jit
+def round_and_trun(x, round_bit, trun_mask):
+    """Round and truncate (usually for accumulator)."""
+    return libdevice.uint_as_float((libdevice.float_as_uint(x) + round_bit) & trun_mask)
+
+
 def tl_matmul_chunk_truncate(
     a,
     b,
     activation="",
     chunk_trun_bits=0,
     chunk_size=16,
+    truncate_then_accumulate=True,
     cast_output_to_input_dtype=None,
 ):
     """Triton matmul for HW behavior simulation. Supports float and int8.
-    a. variable chunk size (i.e., BLOCK_SIZE_K)
-    b. LSB truncation, must <23 if using float.
+    i. variable chunk size (i.e., BLOCK_SIZE_K)
+    ii. LSB truncation, must be <23 if using float.
+    iii. assume D = A*B + C, where C is optional. If C exists, it will be updated in place.
 
     Args:
         a, b: input tensors. FloatX, X in [32, 16, 8] or INT8.
         activation (str, optional): activation func to be fused, see relu example.
         chunk_trun_bits (int, optional): number of LSBs to be truncated/rounded.
         chunk_size (int, optional): BLOCK_SIZE_K, some HW has specific chunk size. must >= 16.
+        truncate_then_accumulate (bool, optional): if True, c = truncate(a*b) + c, otherwise
+                                                   c = truncate(a*b + c).
         cast_output_to_input_dtype (bool, optional): accumulator has higher prec than input, usually
                                                      FP32 or INT32. by default we cast the final
                                                      output to the same dtype as input for non-8bits.
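
As an illustration (not part of the patch itself), here is a minimal NumPy sketch of the bit manipulation that `round_and_trun` applies to the FP32 accumulator, assuming chunk_trun_bits = 8: the float is reinterpreted as uint32, a rounding bit is added at position chunk_trun_bits - 1, and the lowest chunk_trun_bits mantissa bits are then zeroed.

import numpy as np

def round_and_trun_np(x, chunk_trun_bits=8):
    """NumPy model of the Triton helper: round to nearest within the kept bits, then truncate."""
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    round_bit = np.uint32(1 << (chunk_trun_bits - 1)) if chunk_trun_bits > 0 else np.uint32(0)
    trun_mask = np.uint32((0xFFFFFFFF >> chunk_trun_bits) << chunk_trun_bits)
    return ((bits + round_bit) & trun_mask).view(np.float32)

x = np.array([1.0 + 2**-20], dtype=np.float32)  # low mantissa bits are non-zero
print(round_and_trun_np(x))  # -> [1.], the low 8 mantissa bits are rounded away
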
@@ -300,6 +449,7 @@ def tl_matmul_chunk_truncate(
     use empirical way to determine BLOCK sizes, may not be optimal. But need to avoid autotune for
     real model inference. otherwise auto-tune will be triggered in every forward call.
     """
+
     # Check constraints.
     assert a.shape[1] == b.shape[0], "Incompatible dimensions"
     assert a.is_contiguous(), "Matrix A must be contiguous"
@@ -343,9 +493,18 @@ def isPowerofTwo(x):
         mm_kernel = imatmul_kernel
     else:
         acc_dtype = torch.float32
-        mm_kernel = matmul_kernel
+        mm_kernel = matmul_kernel if c is None else matmul_kernel_DABC
         assert chunk_trun_bits < 23, "FP32 accumulator only has 23 mantissa bits"
-    c = torch.zeros((M, N), device=a.device, dtype=acc_dtype)
+
+    if c is None:
+        c_org_dtype = a.dtype
+        c = torch.zeros((M, N), device=a.device, dtype=acc_dtype)
+    else:
+        # if C is in fp16, accumulate in fp32 no matter what; decide whether to cast back later
+        c_org_dtype = c.dtype
+        c = c.to(acc_dtype)
+        assert c.shape[0] == M and c.shape[1] == N, "C shape is inconsistent with A and B."
+        assert acc_dtype == torch.float32, "INT truncation experiment is not yet supported."
 
     # 1D launch kernel where each block gets its own program.
     def grid(META):
@@ -386,7 +545,8 @@ def grid(META):
         c.stride(0),
         c.stride(1),
         chunk_trun_bits=chunk_trun_bits,
+        truncate_then_accumulate=truncate_then_accumulate,
         ACTIVATION=activation,
         **kernel_config,  # if using auto-tune, comment this line out.
     )
-    return c.to(a.dtype) if cast_output_to_input_dtype else c
+    return c.to(c_org_dtype) if cast_output_to_input_dtype else c
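
Usage sketch (illustrative, not from the patch; the shapes and chunk settings below are arbitrary examples, assuming contiguous CUDA float16 inputs):

import torch

a = torch.randn(256, 512, device="cuda", dtype=torch.float16).contiguous()
b = torch.randn(512, 128, device="cuda", dtype=torch.float16).contiguous()

# Truncate 12 LSBs of each per-chunk (BLOCK_SIZE_K = 32) partial product, then accumulate the
# truncated partial sums; truncate_then_accumulate=False would instead truncate after each
# accumulation step, i.e. c = truncate(a*b + c).
d = tl_matmul_chunk_truncate(
    a,
    b,
    chunk_trun_bits=12,
    chunk_size=32,
    truncate_then_accumulate=True,
)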