@@ -266,7 +266,6 @@ def imatmul_kernel(
         else:
             accumulator = accumulator_inner
 
-
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * stride_bk
     if ACTIVATION == "leaky_relu":
@@ -281,7 +280,6 @@ def imatmul_kernel(
     tl.store(c_ptrs, c, mask=c_mask)
 
 
-
 @triton.jit
 def matmul_kernel_DABC(
     # Pointers to matrices
@@ -311,11 +309,11 @@ def matmul_kernel_DABC(
     ACTIVATION: tl.constexpr,
 ):
     """Kernel for computing the matmul D = A x B + C that include LSB truncation.
-    A has shape (M, K), B has shape (K, N) and C/D has shape (M, N).
+    A has shape (M, K), B has shape (K, N) and C/D has shape (M, N).
     NOTE:
     C should be consistent with accumulator dtype, e.g. fp8xfp8 -> fp32.
     *D ptr is supposed to be the same as C ptr, no need to provide D as arg
-    **we can be used C to verify unintended truncation by CUDA as well.
+    **we can be used C to verify unintended truncation by CUDA as well.
     Args:
         chunk_trun_bits (int): number of LSB to truncate/round. [0 to 23]
     """
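For orientation, here is a minimal host-side sketch of the contract this docstring describes, using hypothetical shapes: C is pre-cast to the fp32 accumulator dtype and its buffer is overwritten in place with D.

```python
import torch

# Hypothetical shapes for illustration only.
M, K, N = 64, 128, 32
a = torch.randn(M, K, dtype=torch.float16)
b = torch.randn(K, N, dtype=torch.float16)
c = torch.randn(M, N, dtype=torch.float32)  # already in the accumulator dtype

# With no LSB truncation, this is what the kernel is expected to leave in the C buffer:
d_ref = a.float() @ b.float() + c
```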
@@ -353,9 +351,8 @@ def matmul_kernel_DABC(
     # -----------------------------------------------------------
     # Iterate to compute a block of the C matrix.
     # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
-    # of fp32 values for higher accuracy.
-    # accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-    accumulator = tl.load(c_ptrs, mask=c_mask, other=0.0)  # should have been cast to fp32 already
+    # of fp32 values for higher accuracy, i.e. C should have been cast to fp32 already
+    accumulator = tl.load(c_ptrs, mask=c_mask, other=0.0)
     ## ------ prepare LSB rounding/truncation masks -------
     # NOTE mask will be applied on accumulator, which is alway FP32, so we may truncate up to 23b
     # e.g., 20b -> trun_mask = 0xFFF00000, round_bit = 0x00080000
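The mask arithmetic in the comment above can be illustrated with a small host-side sketch. This is only one plausible reading (add the round bit, then clear the truncated mantissa LSBs through an int32 view); the kernel applies its masks to the fp32 accumulator inside the K-loop, and its exact rounding behaviour may differ.

```python
import torch

def truncate_fp32_lsbs(x: torch.Tensor, chunk_trun_bits: int) -> torch.Tensor:
    """Hypothetical illustration of trun_mask / round_bit applied to fp32 bit patterns."""
    assert x.dtype == torch.float32 and 0 <= chunk_trun_bits <= 23
    if chunk_trun_bits == 0:
        return x.clone()
    trun_mask = -(1 << chunk_trun_bits)     # e.g. 20 -> 0xFFF00000 as a signed int32
    round_bit = 1 << (chunk_trun_bits - 1)  # e.g. 20 -> 0x00080000
    bits = x.view(torch.int32)              # reinterpret the fp32 bits
    # add the rounding bit, then zero out the low mantissa bits
    return ((bits + round_bit) & trun_mask).view(torch.float32)
```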
@@ -477,15 +474,23 @@ def isPowerofTwo(x):
     # insert 0s in between elements, i.e. pad [m,k] -> [m,2k], [k,n]->[2k,n], out=[m,n] unchanged.
     # Do not support I8 or F8 for now. (as F8/FP24 simulation is treated as BF16 currently)
     if chunk_size == 8 and a.dtype in [torch.float16, torch.bfloat16]:
-        a_padded = torch.zeros(a.shape[0], a.shape[1]*2, dtype=a.dtype, device=a.device)
+        a_padded = torch.zeros(
+            a.shape[0], a.shape[1] * 2, dtype=a.dtype, device=a.device
+        )
         a_padded[:, ::2] = a
         a = a_padded
-        b_padded = torch.zeros(b.shape[0]*2, b.shape[1], dtype=b.dtype, device=b.device)
+        b_padded = torch.zeros(
+            b.shape[0] * 2, b.shape[1], dtype=b.dtype, device=b.device
+        )
         b_padded[::2, :] = b
         b = b_padded
         chunk_size = 16
     else:
-        chunk_size = max(chunk_size, min_chunk_size) if isPowerofTwo(chunk_size) else min_chunk_size
+        chunk_size = (
+            max(chunk_size, min_chunk_size)
+            if isPowerofTwo(chunk_size)
+            else min_chunk_size
+        )
 
     M, K = a.shape
     K, N = b.shape
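A quick sanity check of the zero-interleaving trick in this hunk: the inserted zero columns of A line up with the inserted zero rows of B, so the product is unchanged while K is doubled. The shapes below are hypothetical.

```python
import torch

m, k, n = 4, 8, 5
a = torch.randn(m, k, dtype=torch.bfloat16)
b = torch.randn(k, n, dtype=torch.bfloat16)

a_padded = torch.zeros(m, 2 * k, dtype=a.dtype)
a_padded[:, ::2] = a
b_padded = torch.zeros(2 * k, n, dtype=b.dtype)
b_padded[::2, :] = b

# The zero elements contribute nothing to any dot product, so the results match.
assert torch.allclose(a.float() @ b.float(), a_padded.float() @ b_padded.float())
```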
@@ -504,8 +509,8 @@ def isPowerofTwo(x):
     # if C is in fp16, accumulate in fp32 no matter what, decide whether to cast back later
     c_org_dtype = c.dtype
     c = c.to(acc_dtype)
-    assert c.shape[0]==M and c.shape[1]==N, "C shape is inconsistent with A B."
-    assert acc_dtype == torch.float32, "INT truncation experiment is not yet supported."
+    assert c.shape[0] == M and c.shape[1] == N, "C shape is inconsistent with A B."
+    assert acc_dtype == torch.float32, "INT truncation is not yet supported."
 
     # 1D launch kernel where each block gets its own program.
     def grid(META):
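The body of grid() is cut off by this hunk; in Triton's matmul tutorial the equivalent callable flattens the 2D tile grid into 1D, roughly as sketched below. M and N stand in for the shapes computed earlier, and this sketch is an assumption, not the exact code from the commit.

```python
import triton

M, N = 512, 512  # placeholders for the shapes taken from a and b above

def grid(META):
    # one program per [BLOCK_SIZE_M, BLOCK_SIZE_N] output tile, flattened to 1D
    return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),)
```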