
Commit 3655518

adjust int8 triton to enable msb/lsb truncation
Signed-off-by: cliu-us <[email protected]>
1 parent 9e2dc3e commit 3655518

File tree: 2 files changed (+40 −14 lines)

fms_mo/custom_ext_kernels/triton_kernels.py

Lines changed: 26 additions & 10 deletions
```diff
@@ -101,6 +101,7 @@ def matmul_kernel(
     stride_cm,
     stride_cn,
     chunk_trun_bits,
+    max_acc_bits,
     truncate_then_accumulate,
     # Meta-parameters
     BLOCK_SIZE_M: tl.constexpr,
@@ -212,6 +213,7 @@ def imatmul_kernel(
     stride_cm,
     stride_cn,
     chunk_trun_bits,
+    max_acc_bits,
     truncate_then_accumulate,
     # Meta-parameters
     BLOCK_SIZE_M: tl.constexpr,
@@ -220,8 +222,8 @@ def imatmul_kernel(
     GROUP_SIZE_M: tl.constexpr,
     ACTIVATION: tl.constexpr,
 ):
-    """Kernel for computing the INT matmul C = A x B that include LSB truncation. A and B should be
-    INT8, C should be INT32. (Pretty much the same code as float version.)
+    """Kernel for computing the INT matmul D = A x B + C that include LSB truncation and MSB
+    clamping. A and B should be INT8, C/D should be INT32. (similar to the float version.)
     A has shape (M, K), B has shape (K, N) and C has shape (M, N)
     Args:
         chunk_trun_bits (int): number of LSBs to truncate/round.
@@ -238,14 +240,20 @@ def imatmul_kernel(
 
     offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     offs_k = tl.arange(0, BLOCK_SIZE_K)
     a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
     b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
 
-    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)
-    ## ------ prepare LSB rounding/truncation masks -------
+    # accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)
+    accumulator = tl.load(c_ptrs, mask=c_mask, other=0.0)
+    ## ------ prepare MSB/LSB rounding/truncation masks -------
     round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0
-    # msb_mask = 0x00FFFFFF # only needed when simulating truncation on MSB
+    acc_min = -(1 << (max_acc_bits - 1))
+    acc_max = -acc_min - 1
     ## ---------------------------------------------------------
 
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
```
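Two things change in the hunk above: the accumulator is seeded from C via tl.load instead of tl.zeros (which turns the kernel output into D = A x B + C), and the MSB clamp bounds acc_min/acc_max are precomputed from max_acc_bits. A minimal PyTorch sketch of the new accumulator semantics, with made-up shapes and values (illustration only, not library code):

```python
import torch

# Illustration only: the kernel now starts from C instead of an all-zero accumulator,
# so each tile computes D = A x B + C rather than C = A x B.
A = torch.randint(-128, 128, (64, 32), dtype=torch.int8)
B = torch.randint(-128, 128, (32, 16), dtype=torch.int8)
C = torch.randint(-1000, 1000, (64, 16), dtype=torch.int32)

# int64 here only to keep the CPU reference simple; the kernel itself accumulates in INT32.
D = A.to(torch.int64) @ B.to(torch.int64) + C.to(torch.int64)
```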
```diff
@@ -256,7 +264,11 @@ def imatmul_kernel(
         else:
             accumulator_inner = tl.dot(a, b, accumulator, input_precision="ieee")
 
-        ## ------ add chunky LSB rounding/masking --------
+        ## ------ MSB truncation by clamp, chunky LSB truncation by rounding/masking --------
+        if max_acc_bits < 32:
+            accumulator_inner = tl.maximum(
+                tl.minimum(accumulator_inner, acc_max), acc_min
+            )
         if chunk_trun_bits != 0:
             accumulator_inner = (accumulator_inner + round_bit) >> chunk_trun_bits
             accumulator_inner = accumulator_inner << chunk_trun_bits
```
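Per chunk of K, the partial sums are now first clamped into the signed max_acc_bits window and then rounded to a multiple of 2**chunk_trun_bits. A rough PyTorch emulation of that single step, assuming int32 partial sums and the same shift-based rounding as the kernel's round_bit logic (a sketch, not the library API):

```python
import torch

def emulate_chunk_truncation(partial: torch.Tensor, max_acc_bits: int, chunk_trun_bits: int):
    """Clamp int32 chunk partial sums to max_acc_bits, then round away chunk_trun_bits LSBs."""
    if max_acc_bits < 32:
        acc_min = -(1 << (max_acc_bits - 1))
        acc_max = -acc_min - 1
        partial = partial.clamp(acc_min, acc_max)  # MSB truncation as a saturating clamp
    if chunk_trun_bits != 0:
        round_bit = 1 << (chunk_trun_bits - 1)
        # LSB rounding: add half an LSB, drop the LSBs, shift back (arithmetic shift for signed)
        partial = ((partial + round_bit) >> chunk_trun_bits) << chunk_trun_bits
    return partial

# emulate_chunk_truncation(torch.tensor([300000, -777], dtype=torch.int32), 18, 8)
# -> tensor([131072, -768], dtype=torch.int32)
```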
```diff
@@ -275,8 +287,6 @@ def imatmul_kernel(
 
     offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
-    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
     tl.store(c_ptrs, c, mask=c_mask)
 
 
@@ -300,6 +310,7 @@ def matmul_kernel_DABC(
     stride_cm,
     stride_cn,
     chunk_trun_bits,
+    max_acc_bits,
     truncate_then_accumulate,
     # Meta-parameters
     BLOCK_SIZE_M: tl.constexpr,
```
```diff
@@ -421,6 +432,7 @@ def tl_matmul_chunk_truncate(
     activation="",
     chunk_trun_bits=0,
     chunk_size=16,
+    max_acc_bits=32,
     truncate_then_accumulate=True,
     cast_output_to_input_dtype=None,
 ):
@@ -434,6 +446,9 @@ def tl_matmul_chunk_truncate(
         activation (str, optional): activation func to be fused, see relu example.
         chunk_trun_bits (int, optional): number of LSBs to be truncated/rounded.
         chunk_size (int, optional): BLOCK_SIZE_K, some HW has specific chunk size. must >= 16.
+        max_acc_bits (int, optional): num of bits for the accumulator, e.g. if INT24 is used, will
+            clamp each chunk of a*b to [-2**23-1, 2**23].
+            (assuming no inf when overflow)
         truncate_then_accumulate (bool, optional): if True, c = truncate(a*b) + c, otherwise
             c = truncate(a*b+c)
         cast_output_to_input_dtype (bool, optional): accumulator has higher prec than input, usually
```
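For reference, the clamp window the kernel actually derives from max_acc_bits (acc_min = -(1 << (max_acc_bits - 1)), acc_max = -acc_min - 1) is the usual signed integer range of that width. A small sketch, not part of the library API:

```python
def acc_clamp_range(max_acc_bits: int):
    """Signed range implied by a max_acc_bits-wide accumulator, as computed in the kernel."""
    acc_min = -(1 << (max_acc_bits - 1))
    acc_max = -acc_min - 1
    return acc_min, acc_max

print(acc_clamp_range(24))  # (-8388608, 8388607)  i.e. [-2**23, 2**23 - 1]
print(acc_clamp_range(18))  # (-131072, 131071)    i.e. [-2**17, 2**17 - 1]
```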
```diff
@@ -472,9 +487,9 @@ def isPowerofTwo(x):
 
     # because min k (chunk size in this case) for fp16/bf16 is 16, if smaller is needed, we could
     # insert 0s in between elements, e.g. pad [m,k] -> [m,2k], [k,n]->[2k,n], out=[m,n] unchanged.
-    # Do not support INT8 for now.
     if chunk_size == 8 and a.dtype in [
         torch.float8_e4m3fn,
+        torch.int8,
         torch.float16,
         torch.bfloat16,
     ]:
```
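The zero-interleaving idea described in the comment above (pad [m,k] -> [m,2k] and [k,n] -> [2k,n] so a logical chunk of 8 behaves like a tl.dot-legal chunk of 16, with the matmul result unchanged) can be sketched in plain PyTorch; function name and shapes are illustrative only:

```python
import torch

def interleave_zeros(a: torch.Tensor, b: torch.Tensor):
    """Pad the K dim with interleaved zeros: [m,k]->[m,2k], [k,n]->[2k,n]; a @ b is unchanged."""
    m, k = a.shape
    _, n = b.shape
    a2 = torch.zeros(m, 2 * k, dtype=a.dtype)
    a2[:, ::2] = a          # original columns land on even positions
    b2 = torch.zeros(2 * k, n, dtype=b.dtype)
    b2[::2, :] = b          # original rows land on even positions
    # the zero/zero pairs at odd positions contribute nothing, so a2 @ b2 == a @ b
    return a2, b2
```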
```diff
@@ -515,7 +530,7 @@ def isPowerofTwo(x):
         c_org_dtype = c.dtype
         c = c.to(acc_dtype)
         assert c.shape[0] == M and c.shape[1] == N, "C shape is inconsistent with A B."
-    assert acc_dtype == torch.float32, "INT truncation is not yet supported."
+    # assert acc_dtype == torch.float32, "INT truncation is not yet supported."
 
     # 1D launch kernel where each block gets its own program.
     def grid(META):
@@ -556,6 +571,7 @@ def grid(META):
         c.stride(0),
         c.stride(1),
         chunk_trun_bits=chunk_trun_bits,
+        max_acc_bits=max_acc_bits,
         truncate_then_accumulate=truncate_then_accumulate,
         ACTIVATION=activation,
         **kernel_config,  # if using auto-tune, comment this line out.
```
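Putting it together, a call that exercises both knobs might look like the sketch below. The alias mirrors how the test file refers to the wrapper (it calls it tl_matmul), and the import path is an assumption based on the file location shown above:

```python
import torch
# Assumed import path, based on fms_mo/custom_ext_kernels/triton_kernels.py shown above.
from fms_mo.custom_ext_kernels.triton_kernels import tl_matmul_chunk_truncate as tl_matmul

a = torch.randint(-128, 128, (256, 512), dtype=torch.int8, device="cuda")
b = torch.randint(-128, 128, (512, 256), dtype=torch.int8, device="cuda")

# INT8 x INT8 with an emulated INT24 accumulator and 8 LSBs rounded away per chunk.
out = tl_matmul(a, b, chunk_trun_bits=8, max_acc_bits=24)
```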

tests/triton_kernels/test_triton_mm.py

Lines changed: 14 additions & 4 deletions
```diff
@@ -94,13 +94,23 @@ def test_triton_matmul_int8(mkn):
     torch_output = torch.matmul(a.to(torch.float), b.to(torch.float))
     # cast tl_matmul results to float because torch.norm only supports float
     tl_output_no_trun = tl_matmul(a, b).to(torch.float)
+    # check LSB truncation effect
     tl_output_trun_8b = tl_matmul(a, b, chunk_trun_bits=8).to(torch.float)
+    # check MSB truncation effect
+    # max(1 int8 * 1 int8) ~ 2^17 -> each chunk acc 32 elem, possible max ~ 2^22
+    # -> truncate to 18b -> should see large err than LSB-only case
+    tl_output_trun_18b8b = tl_matmul(a, b, max_acc_bits=18, chunk_trun_bits=8).to(
+        torch.float
+    )
 
-    diff_no_trun = torch_output - tl_output_no_trun
-    diff_trun_8b = torch_output - tl_output_trun_8b
+    ref = torch.norm(torch_output)
+    rel_err_no_trun = torch.norm(torch_output - tl_output_no_trun) / ref
+    rel_err_trun_8b = torch.norm(torch_output - tl_output_trun_8b) / ref
+    rel_err_trun_18b8b = torch.norm(torch_output - tl_output_trun_18b8b) / ref
 
-    assert torch.norm(diff_no_trun) / torch.norm(torch_output) < 1e-5
-    assert torch.norm(diff_trun_8b) / torch.norm(torch_output) < 1e-2
+    assert rel_err_no_trun < 1e-5
+    assert rel_err_trun_8b < 1e-2
+    assert rel_err_trun_18b8b < 1e-2
 
 
 @pytest.mark.parametrize("feat_in_out", [(64, 128), (256, 1024), (1024, 4096)])
```
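As a rough sanity check on the new 18-bit case: with signed INT8 inputs the largest single product magnitude is 128 * 128 = 2**14, so a chunk of 16 or 32 accumulated elements can reach roughly 2**18 to 2**19, which is outside the [-2**17, 2**17 - 1] window that max_acc_bits=18 clamps to; hence the 18-bit result is expected to show a larger error than the LSB-only case while still staying below the 1e-2 bound. A back-of-envelope sketch (the chunk depths are assumptions, not read from the kernel config):

```python
max_prod = 128 * 128                  # worst-case |int8 * int8|, about 2**14
int18_max = (1 << 17) - 1             # upper clamp bound for max_acc_bits=18
for k_chunk in (16, 32):              # plausible BLOCK_SIZE_K values
    worst_chunk_sum = max_prod * k_chunk
    print(k_chunk, worst_chunk_sum, worst_chunk_sum > int18_max)  # clamping can trigger
```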
