
Commit dc2ad5d

Merge pull request #82 from chichun-charlie-liu/fp24_acc_trun_chunk8
feat: triton matmul kernel adjusted, now is closer to HW behavior
2 parents: 4e59ef1 + dbd540d

File tree

2 files changed: +240 −22 lines changed

fms_mo/custom_ext_kernels/triton_kernels.py

Lines changed: 199 additions & 18 deletions
@@ -101,6 +101,7 @@ def matmul_kernel(
     stride_cm,
     stride_cn,
     chunk_trun_bits,
+    truncate_then_accumulate,
     # Meta-parameters
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
@@ -159,15 +160,20 @@ def matmul_kernel(
         a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
         b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
         # We accumulate along the K dimension.
-        accumulator = tl.dot(a, b, accumulator, input_precision="ieee")
+        if truncate_then_accumulate:
+            accumulator_inner = tl.dot(a, b, input_precision="ieee")
+        else:
+            accumulator_inner = tl.dot(a, b, accumulator, input_precision="ieee")
         # tl.dot() default is using TF32 approximation, not good enough for LSB truncation exp

         ## ------ add chunky LSB rounding/masking --------
         if chunk_trun_bits > 0:
-            accumulator = libdevice.uint_as_float(
-                (libdevice.float_as_uint(accumulator) + round_bit) & trun_mask
-            )
+            accumulator_inner = round_and_trun(accumulator_inner, round_bit, trun_mask)
         ## ---------------------------------------------------------
+        if truncate_then_accumulate:
+            accumulator += accumulator_inner
+        else:
+            accumulator = accumulator_inner

         # Advance the ptrs to the next K block.
         a_ptrs += BLOCK_SIZE_K * stride_ak
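This restructuring is the core of the change: each K-chunk's partial product can now be truncated before it is folded into the running sum, instead of truncating the running sum itself. A minimal PyTorch sketch of the two orderings, assuming fp32 inputs (`trunc_fp32` and `matmul_chunked_ref` are illustrative names, not part of this file):

```python
import torch

def trunc_fp32(x: torch.Tensor, n_bits: int) -> torch.Tensor:
    """Round to nearest, then clear the n_bits LSBs of the FP32 mantissa."""
    if n_bits == 0:
        return x
    u = x.view(torch.int32)        # reinterpret the bits, no copy
    round_bit = 1 << (n_bits - 1)
    mask = -(1 << n_bits)          # int32 bit pattern of (0xFFFFFFFF >> n) << n
    return ((u + round_bit) & mask).view(torch.float32)

def matmul_chunked_ref(a, b, chunk=16, n_bits=8, truncate_then_accumulate=True):
    """Reference for the kernel's K-loop, entirely in fp32."""
    acc = torch.zeros(a.shape[0], b.shape[1])
    for k0 in range(0, a.shape[1], chunk):
        partial = a[:, k0:k0 + chunk] @ b[k0:k0 + chunk, :]
        if truncate_then_accumulate:
            acc += trunc_fp32(partial, n_bits)       # c = trunc(a*b) + c
        else:
            acc = trunc_fp32(acc + partial, n_bits)  # c = trunc(a*b + c)
    return acc
```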
@@ -206,6 +212,7 @@ def imatmul_kernel(
     stride_cm,
     stride_cn,
     chunk_trun_bits,
+    truncate_then_accumulate,
     # Meta-parameters
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
@@ -244,13 +251,20 @@ def imatmul_kernel(
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
         b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
-        accumulator = tl.dot(a, b, accumulator, input_precision="ieee")
+        if truncate_then_accumulate:
+            accumulator_inner = tl.dot(a, b, input_precision="ieee")
+        else:
+            accumulator_inner = tl.dot(a, b, accumulator, input_precision="ieee")

         ## ------ add chunky LSB rounding/masking --------
         if chunk_trun_bits != 0:
-            accumulator = (accumulator + round_bit) >> chunk_trun_bits
-            accumulator = accumulator << chunk_trun_bits
+            accumulator_inner = (accumulator_inner + round_bit) >> chunk_trun_bits
+            accumulator_inner = accumulator_inner << chunk_trun_bits
         ## ---------------------------------------------------------
+        if truncate_then_accumulate:
+            accumulator += accumulator_inner
+        else:
+            accumulator = accumulator_inner

         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * stride_bk
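The integer kernel gets the same restructuring, but truncation is done with a shift pair instead of a bit mask, since its accumulator is INT32. The rounding arithmetic in plain Python (`trunc_int32` is an illustrative name, not part of this file):

```python
def trunc_int32(x: int, n_bits: int) -> int:
    """Round to nearest, then zero the n_bits LSBs, mirroring the shift pair above."""
    if n_bits == 0:
        return x
    round_bit = 1 << (n_bits - 1)
    return ((x + round_bit) >> n_bits) << n_bits

assert trunc_int32(23, 4) == 16  # 23 + 8 = 31, floored to a multiple of 16
assert trunc_int32(25, 4) == 32  # 25 + 8 = 33, rounds up
```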
@@ -266,29 +280,162 @@ def imatmul_kernel(
     tl.store(c_ptrs, c, mask=c_mask)


+@triton.jit
+def matmul_kernel_DABC(
+    # Pointers to matrices
+    a_ptr,
+    b_ptr,
+    c_ptr,
+    # Matrix dimensions
+    M,
+    N,
+    K,
+    # The stride variables represent how much to increase the ptr by when moving by 1
+    # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
+    # by to get the element one row down (A has M rows).
+    stride_am,
+    stride_ak,
+    stride_bk,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    chunk_trun_bits,
+    truncate_then_accumulate,
+    # Meta-parameters
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+    ACTIVATION: tl.constexpr,
+):
+    """Kernel for computing the matmul D = A x B + C that includes LSB truncation.
+    A has shape (M, K), B has shape (K, N), and C/D has shape (M, N).
+    NOTE:
+        C should be consistent with the accumulator dtype, e.g. fp8 x fp8 -> fp32.
+        *The D ptr is supposed to be the same as the C ptr, so no need to provide D as an arg.
+        **C can also be used to verify unintended truncation by CUDA.
+    Args:
+        chunk_trun_bits (int): number of LSBs to truncate/round. [0 to 23]
+    """
+    # -----------------------------------------------------------
+    # Map program ids `pid` to the block of C it should compute.
+    # This is done in a grouped ordering to promote L2 data reuse.
+    # See above `L2 Cache Optimizations` section for details.
+    pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+    group_id = pid // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
+    pid_n = (pid % num_pid_in_group) // group_size_m
+
+    # ----------------------------------------------------------
+    # Create pointers for the first blocks of A and B.
+    # We will advance this pointer as we move in the K direction
+    # and accumulate.
+    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
+    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
+    # See above `Pointer Arithmetic` section for details
+    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+
+    # -----------------------------------------------------------
+    # Iterate to compute a block of the C matrix.
+    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
+    # of fp32 values for higher accuracy, i.e. C should have been cast to fp32 already.
+    accumulator = tl.load(c_ptrs, mask=c_mask, other=0.0)
+    ## ------ prepare LSB rounding/truncation masks -------
+    # NOTE the mask is applied to the accumulator, which is always FP32, so we may truncate up to 23b
+    # e.g., 20b -> trun_mask = 0xFFF00000, round_bit = 0x00080000
+    #        8b -> trun_mask = 0xFFFFFF00, round_bit = 0x00000080
+    trun_mask = tl.cast((0xFFFFFFFF >> chunk_trun_bits) << chunk_trun_bits, tl.uint32)
+    round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0
+    ## ---------------------------------------------------------
+
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        # Load the next block of A, B, and C, generate a mask by checking the K dimension.
+        # If it is out of bounds, set it to 0.
+        # D = truncation(A*B) + C
+        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+        # We accumulate along the K dimension, but apply truncation to the local A*B first.
+        if truncate_then_accumulate:
+            accumulator_inner = tl.dot(a, b, input_precision="ieee")
+        else:
+            accumulator_inner = tl.dot(a, b, accumulator, input_precision="ieee")
+        # tl.dot() defaults to the TF32 approximation, not good enough for the LSB truncation exp.
+        # NOTE: tl.dot(a, b, c) should correspond to a CUDA mma instruction, typically "c = a*b+c".
+        #       If this mma instruction uses "reduced precision" under the hood, not only will a*b
+        #       be accumulated in that precision, there's a chance c will be cast to that "lower"
+        #       precision as well, hence, could lose some precision!
+
+        ## ------ add chunky LSB rounding/masking --------
+        if chunk_trun_bits > 0:
+            accumulator_inner = round_and_trun(accumulator_inner, round_bit, trun_mask)
+        ## ---------------------------------------------------------
+        if truncate_then_accumulate:
+            accumulator += accumulator_inner
+        else:
+            accumulator = accumulator_inner
+
+        # Advance the ptrs to the next K block.
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+    # You can fuse arbitrary activation functions here
+    # while the accumulator is still in FP32!
+    if ACTIVATION == "leaky_relu":
+        accumulator = leaky_relu(accumulator)
+
+    d = accumulator  # do not cast to (tl.float16) just yet
+
+    # -----------------------------------------------------------
+    # Write back the block of the output to matrix "C" with masks.
+    tl.store(c_ptrs, d, mask=c_mask)
+
+
 @triton.jit
 def leaky_relu(x):
     """Activation function that could be fused into matmul kernel"""
     return tl.where(x >= 0, x, 0.01 * x)


+@triton.jit
+def round_and_trun(x, round_bit, trun_mask):
+    """Round and truncate (usually for the accumulator)."""
+    return libdevice.uint_as_float((libdevice.float_as_uint(x) + round_bit) & trun_mask)
+
+
 def tl_matmul_chunk_truncate(
     a,
     b,
+    c=None,
     activation="",
     chunk_trun_bits=0,
     chunk_size=16,
+    truncate_then_accumulate=True,
     cast_output_to_input_dtype=None,
 ):
     """Triton matmul for HW behavior simulation. Supports float and int8.
-    a. variable chunk size (i.e., BLOCK_SIZE_K)
-    b. LSB truncation, must <23 if using float.
+    i. variable chunk size (i.e., BLOCK_SIZE_K)
+    ii. LSB truncation, must be <23 if using float.
+    iii. assume D = A*B + C, where C is optional. If C exists, it will be updated in place.

     Args:
         a, b: input tensors. FloatX, X in [32, 16, 8] or INT8.
         activation (str, optional): activation func to be fused, see relu example.
         chunk_trun_bits (int, optional): number of LSBs to be truncated/rounded.
         chunk_size (int, optional): BLOCK_SIZE_K; some HW has a specific chunk size. must be >= 16.
+        truncate_then_accumulate (bool, optional): if True, c = truncate(a*b) + c, otherwise
+                                                   c = truncate(a*b + c)
         cast_output_to_input_dtype (bool, optional): accumulator has higher prec than input, usually
                                                      FP32 or INT32. by default we cast the final
                                                      output to the same dtype as input for non-8bits.
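Besides `matmul_kernel_DABC`, the hunk above factors the rounding/masking pair used by the float kernels into `round_and_trun`. A pure-Python mirror makes the bit manipulation easy to check (`round_and_trun_scalar` is an illustrative helper, not part of this file):

```python
import struct

def round_and_trun_scalar(x: float, n_bits: int) -> float:
    u = struct.unpack("<I", struct.pack("<f", x))[0]             # fp32 bits as uint32
    trun_mask = ((0xFFFFFFFF >> n_bits) << n_bits) & 0xFFFFFFFF  # e.g. 8 -> 0xFFFFFF00
    u = (u + (1 << (n_bits - 1))) & trun_mask                    # round bit, e.g. 8 -> 0x80
    return struct.unpack("<f", struct.pack("<I", u))[0]

print(round_and_trun_scalar(1.0 + 2**-20, 8))  # 1.0: the 2**-20 bit sits in the cleared LSBs
print(round_and_trun_scalar(1.0 + 2**-16, 8))  # 1.000030517578125: the round bit carries up
```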
@@ -300,6 +447,7 @@ def tl_matmul_chunk_truncate(
     uses an empirical way to determine BLOCK sizes, which may not be optimal. But we need to avoid
     autotune for real model inference, otherwise auto-tune would be triggered in every forward call.
     """
+
     # Check constraints.
     assert a.shape[1] == b.shape[0], "Incompatible dimensions"
     assert a.is_contiguous(), "Matrix A must be contiguous"
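For orientation, a plausible call with the new arguments (the shapes, dtypes, and fp32 C below are arbitrary choices for this sketch, not values from the PR):

```python
import torch

a = torch.randn(256, 512, device="cuda", dtype=torch.float16)
b = torch.randn(512, 128, device="cuda", dtype=torch.float16)
c = torch.zeros(256, 128, device="cuda", dtype=torch.float32)  # same dtype as the accumulator

# D = trunc(A@B, 8 LSBs per 16-wide K-chunk) + C; with c given, C is updated in place
out = tl_matmul_chunk_truncate(a, b, c=c, chunk_trun_bits=8, chunk_size=16)
```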
@@ -314,28 +462,60 @@ def tl_matmul_chunk_truncate(
     if cuda_cc[0] >= 9 or cuda_cc == (8, 9):
         allowed_dtypes += DTYPE_F8
     assert a.dtype in allowed_dtypes, "Input dtype is not supported"
-    M, K = a.shape
-    K, N = b.shape

     # Allocates output, always accumulate in FP32 (if floats) or INT32 then cast
     def isPowerofTwo(x):
         """triton-specific limitation: block size needs to be power of 2."""
         return (x & (x - 1)) == 0

     min_chunk_size = 32 if a.dtype in DTYPE_8BIT else 16
-    if isPowerofTwo(chunk_size):
-        chunk_size = max(chunk_size, min_chunk_size)
-    else:
+
+    # Because the min K (chunk size in this case) for fp16/bf16 is 16, if a smaller one is needed
+    # we can insert 0s between elements, e.g. pad [m,k] -> [m,2k] and [k,n] -> [2k,n];
+    # out = [m,n] is unchanged. INT8 is not supported for now.
+    if chunk_size == 8 and a.dtype in [
+        torch.float8_e4m3fn,
+        torch.float16,
+        torch.bfloat16,
+    ]:
+        exp_ratio = min_chunk_size // chunk_size
+        a_padded = torch.zeros(
+            a.shape[0], a.shape[1] * exp_ratio, dtype=a.dtype, device=a.device
+        )
+        a_padded[:, ::exp_ratio] = a
+        a = a_padded
+        b_padded = torch.zeros(
+            b.shape[0] * exp_ratio, b.shape[1], dtype=b.dtype, device=b.device
+        )
+        b_padded[::exp_ratio, :] = b
+        b = b_padded
         chunk_size = min_chunk_size
+    else:
+        chunk_size = (
+            max(chunk_size, min_chunk_size)
+            if isPowerofTwo(chunk_size)
+            else min_chunk_size
+        )

+    M, K = a.shape
+    K, N = b.shape
     if a.dtype in DTYPE_I8:
         acc_dtype = torch.int32
         mm_kernel = imatmul_kernel
     else:
         acc_dtype = torch.float32
-        mm_kernel = matmul_kernel
+        mm_kernel = matmul_kernel if c is None else matmul_kernel_DABC
         assert chunk_trun_bits < 23, "FP32 accumulator only has 23 mantissa bits"
-    c = torch.zeros((M, N), device=a.device, dtype=acc_dtype)
+
+    if c is None:
+        c_org_dtype = a.dtype
+        c = torch.zeros((M, N), device=a.device, dtype=acc_dtype)
+    else:
+        # if C is in fp16, accumulate in fp32 no matter what; decide whether to cast back later
+        c_org_dtype = c.dtype
+        c = c.to(acc_dtype)
+        assert c.shape[0] == M and c.shape[1] == N, "C shape is inconsistent with A and B."
+        assert acc_dtype == torch.float32, "INT truncation is not yet supported."

     # 1D launch kernel where each block gets its own program.
     def grid(META):
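The zero-interleaving trick above works because a zero times anything contributes nothing to a dot product, so 8 real elements padded out to a 16-wide block behave like a hardware chunk of 8. A quick equivalence check in plain PyTorch:

```python
import torch

a, b = torch.randn(4, 8), torch.randn(8, 3)
r = 2                                     # exp_ratio = min_chunk_size // chunk_size
a_p = torch.zeros(4, 8 * r)
a_p[:, ::r] = a                           # a's columns land on even indices
b_p = torch.zeros(8 * r, 3)
b_p[::r, :] = b                           # b's rows land on even indices
assert torch.allclose(a_p @ b_p, a @ b)   # the zero pairs add nothing
```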
@@ -345,7 +525,7 @@ def grid(META):

     if M < 1024 or N < 1024:
         kernel_config = {
-            "BLOCK_SIZE_M": 128,
+            "BLOCK_SIZE_M": 64,
             "BLOCK_SIZE_K": chunk_size,
             "BLOCK_SIZE_N": 32,
             "GROUP_SIZE_M": 8,
@@ -376,7 +556,8 @@ def grid(META):
         c.stride(0),
         c.stride(1),
         chunk_trun_bits=chunk_trun_bits,
+        truncate_then_accumulate=truncate_then_accumulate,
         ACTIVATION=activation,
         **kernel_config,  # if using auto-tune, comment this line out.
     )
-    return c.to(a.dtype) if cast_output_to_input_dtype else c
+    return c.to(c_org_dtype) if cast_output_to_input_dtype else c
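One behavioral consequence of this last change: with `cast_output_to_input_dtype=True`, the result is now cast back to C's original dtype rather than A's. Continuing the usage sketch from earlier (with the fp16 `a` and `b` defined there):

```python
# A/B in fp16 but C kept in fp32: the result now stays fp32 (c_org_dtype),
# where the old code would have cast it down to a.dtype (fp16).
c32 = torch.zeros(256, 128, device="cuda", dtype=torch.float32)
out = tl_matmul_chunk_truncate(a, b, c=c32, cast_output_to_input_dtype=True)
assert out.dtype == torch.float32
```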
