Skip to content

Commit cf39d3d

Browse files
SageAttn method is not ideal to verify accumulator precision, change descriptions to avoid confusion
Signed-off-by: cliu-us <[email protected]>
1 parent a4ea093 commit cf39d3d

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

fms_mo/custom_ext_kernels/triton_kernels.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -376,10 +376,10 @@ def matmul_kernel_DABC(
376376
else:
377377
accumulator_inner = tl.dot(a, b, accumulator, input_precision="ieee")
378378
# tl.dot() default is using TF32 approximation, not good enough for LSB truncation exp
379-
# NOTE: tl.dot(a, b, c) should use one single CUDA mma instruction to handle "c = a*b+c". If
380-
# this mma instruction uses "reduced-precision" under the hood, not only a*b will
381-
# be accumulated in that precision, c most likely will be cast to that "lower"
382-
# precision first, hence, will lose some precision!
379+
# NOTE: tl.dot(a, b, c) should correspond to a CUDA mma instruction, typically "c = a*b+c".
380+
# If this mma instruction uses "reduced-precision" under the hood, not only a*b will
381+
# be accumulated in that precision, there's a chance c will be cast to that "lower"
382+
# precision as well, hence, could lose some precision!
383383

384384
## ------ add chunky LSB rounding/masking --------
385385
if chunk_trun_bits > 0:

0 commit comments

Comments
 (0)