
Commit 578759f

Merge pull request #120 from chichun-charlie-liu/int-triton-kernel-adj
feat: adjust int8 triton to enable msb/lsb truncation
2 parents 418f682 + 4a1201e commit 578759f

10 files changed: +373, -129 lines

fms_mo/custom_ext_kernels/triton_kernels.py

Lines changed: 26 additions & 10 deletions
@@ -101,6 +101,7 @@ def matmul_kernel(
     stride_cm,
     stride_cn,
     chunk_trun_bits,
+    max_acc_bits,  # pylint: disable=unused-argument
     truncate_then_accumulate,
     # Meta-parameters
     BLOCK_SIZE_M: tl.constexpr,
@@ -212,6 +213,7 @@ def imatmul_kernel(
     stride_cm,
     stride_cn,
     chunk_trun_bits,
+    max_acc_bits,
     truncate_then_accumulate,
     # Meta-parameters
     BLOCK_SIZE_M: tl.constexpr,
@@ -220,8 +222,8 @@ def imatmul_kernel(
     GROUP_SIZE_M: tl.constexpr,
     ACTIVATION: tl.constexpr,
 ):
-    """Kernel for computing the INT matmul C = A x B that include LSB truncation. A and B should be
-    INT8, C should be INT32. (Pretty much the same code as float version.)
+    """Kernel for computing the INT matmul D = A x B + C that include LSB truncation and MSB
+    clamping. A and B should be INT8, C/D should be INT32. (similar to the float version.)
     A has shape (M, K), B has shape (K, N) and C has shape (M, N)
     Args:
         chunk_trun_bits (int): number of LSBs to truncate/round.
@@ -238,14 +240,20 @@ def imatmul_kernel(

     offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     offs_k = tl.arange(0, BLOCK_SIZE_K)
     a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
     b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)

-    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)
-    ## ------ prepare LSB rounding/truncation masks -------
+    # accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)
+    accumulator = tl.load(c_ptrs, mask=c_mask, other=0.0)
+    ## ------ prepare MSB/LSB rounding/truncation masks -------
     round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0
-    # msb_mask = 0x00FFFFFF  # only needed when simulating truncation on MSB
+    acc_min = -(1 << (max_acc_bits - 1))
+    acc_max = -acc_min - 1
     ## ---------------------------------------------------------

     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
@@ -256,7 +264,12 @@ def imatmul_kernel(
         else:
             accumulator_inner = tl.dot(a, b, accumulator, input_precision="ieee")

-        ## ------ add chunky LSB rounding/masking --------
+        ## ------ INT MSB truncation is simulated by clamping,
+        # "special" INT LSB truncation by right and left shift --------
+        if max_acc_bits < 32:
+            accumulator_inner = tl.maximum(
+                tl.minimum(accumulator_inner, acc_max), acc_min
+            )
         if chunk_trun_bits != 0:
             accumulator_inner = (accumulator_inner + round_bit) >> chunk_trun_bits
             accumulator_inner = accumulator_inner << chunk_trun_bits
@@ -275,8 +288,6 @@ def imatmul_kernel(

     offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
-    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
     tl.store(c_ptrs, c, mask=c_mask)


@@ -300,6 +311,7 @@ def matmul_kernel_DABC(
     stride_cm,
     stride_cn,
     chunk_trun_bits,
+    max_acc_bits,  # pylint: disable=unused-argument
     truncate_then_accumulate,
     # Meta-parameters
     BLOCK_SIZE_M: tl.constexpr,
@@ -421,6 +433,7 @@ def tl_matmul_chunk_truncate(
     activation="",
     chunk_trun_bits=0,
     chunk_size=16,
+    max_acc_bits=32,
     truncate_then_accumulate=True,
     cast_output_to_input_dtype=None,
 ):
@@ -434,6 +447,9 @@ def tl_matmul_chunk_truncate(
         activation (str, optional): activation func to be fused, see relu example.
         chunk_trun_bits (int, optional): number of LSBs to be truncated/rounded.
         chunk_size (int, optional): BLOCK_SIZE_K, some HW has specific chunk size. must >= 16.
+        max_acc_bits (int, optional): num of bits for the accumulator, e.g. if INT24 is used, will
+                                      clamp each chunk of a*b to [-2**23-1, 2**23].
+                                      (assuming no inf when overflow)
         truncate_then_accumulate (bool, optional): if True, c = truncate(a*b) + c, otherwise
                                                    c = truncate(a*b+c)
         cast_output_to_input_dtype (bool, optional): accumulator has higher prec than input, usually
@@ -472,9 +488,9 @@ def isPowerofTwo(x):

     # because min k (chunk size in this case) for fp16/bf16 is 16, if smaller is needed, we could
     # insert 0s in between elements, e.g. pad [m,k] -> [m,2k], [k,n]->[2k,n], out=[m,n] unchanged.
-    # Do not support INT8 for now.
     if chunk_size == 8 and a.dtype in [
         torch.float8_e4m3fn,
+        torch.int8,
         torch.float16,
         torch.bfloat16,
     ]:
@@ -515,7 +531,6 @@ def isPowerofTwo(x):
         c_org_dtype = c.dtype
         c = c.to(acc_dtype)
         assert c.shape[0] == M and c.shape[1] == N, "C shape is inconsistent with A B."
-        assert acc_dtype == torch.float32, "INT truncation is not yet supported."

     # 1D launch kernel where each block gets its own program.
     def grid(META):
@@ -556,6 +571,7 @@ def grid(META):
         c.stride(0),
         c.stride(1),
         chunk_trun_bits=chunk_trun_bits,
+        max_acc_bits=max_acc_bits,
         truncate_then_accumulate=truncate_then_accumulate,
         ACTIVATION=activation,
         **kernel_config,  # if using auto-tune, comment this line out.
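
The new INT path can be sanity-checked outside Triton with ordinary tensor ops. Below is a minimal PyTorch sketch of the same per-chunk arithmetic as imatmul_kernel: each k-chunk's partial product is clamped to the max_acc_bits range (MSB truncation) and its lowest chunk_trun_bits are rounded away via a right-then-left shift (LSB truncation). The helper name emulate_chunked_int_matmul and the int64 accumulation are conveniences for this sketch, not part of the commit.

import torch

def emulate_chunked_int_matmul(
    a, b, chunk_size=32, max_acc_bits=24, chunk_trun_bits=4, truncate_then_accumulate=True
):
    """Slow reference for the kernel's per-chunk MSB clamp + LSB round/truncate.
    a: (M, K) int8, b: (K, N) int8. Accumulation is done in int64 here purely for
    convenience; the Triton kernel uses an INT32 accumulator."""
    acc_min = -(1 << (max_acc_bits - 1))  # e.g. -2**23 for an INT24 accumulator
    acc_max = -acc_min - 1                # e.g.  2**23 - 1
    round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0

    K = a.shape[1]
    acc = torch.zeros(a.shape[0], b.shape[1], dtype=torch.int64)
    for k0 in range(0, K, chunk_size):
        # partial product of one k-chunk
        part = a[:, k0:k0 + chunk_size].to(torch.int64) @ b[k0:k0 + chunk_size, :].to(torch.int64)
        if not truncate_then_accumulate:
            part = part + acc                        # truncate(a*b + c) variant
        if max_acc_bits < 32:                        # MSB truncation simulated by clamping
            part = part.clamp(acc_min, acc_max)
        if chunk_trun_bits != 0:                     # LSB rounding via right-then-left shift
            part = ((part + round_bit) >> chunk_trun_bits) << chunk_trun_bits
        acc = acc + part if truncate_then_accumulate else part
    return acc

# quick check against the exact result
a = torch.randint(-128, 128, (64, 256), dtype=torch.int8)
b = torch.randint(-128, 128, (256, 64), dtype=torch.int8)
exact = a.to(torch.int64) @ b.to(torch.int64)
approx = emulate_chunked_int_matmul(a, b)
print((exact - approx).abs().max())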

fms_mo/custom_ext_kernels/utils.py

Lines changed: 68 additions & 7 deletions
@@ -633,14 +633,17 @@ def exv2_i4f16_fxinputs_abstract(


 def imatmul_ops_reg(
-    useCUTLASS=True, mm_func=torch.matmul, AB_dtype=torch.float, D_dtype=torch.float
+    useINTkernel="triton",
+    mm_func=torch.matmul,
+    AB_dtype=torch.float,
+    D_dtype=torch.float,
 ):
     """This function will register a dummy Q_imatmul Op for better "graph representation".
     Args:
-        useCUTLASS: bool. choose to use a) real INT matmul using cutlass kernel or b) "simulated"
-                    imatmul using torch.matmul.
+        useINTkernel: str|bool. ["cutlass", "triton", False]. choose to use a) real INT matmul, e.g.
+                    cutlass or triton kernel or b) "simulated" imatmul using torch.matmul.
                     For b), could use D_dtype to select fp16 or fp32 accumulation
-        mm_func: matmul func to be used when useCUTLASS is True, should be a real callable kernel
+        mm_func: matmul func to be used when useINTkernel is True, should be a real callable kernel
                  from cutlass, but for debug purpose, could use torch.matmul as well.
         AB_dtype: datatype for input tensors
         D_dtype: datatype for accumulation and output tensor
@@ -697,10 +700,10 @@ def imatmul(m1, m2):
         tar_shape = tuple(m1.shape[:-1]) + (m2.shape[1],)
         m1 = m1.view(re_shape)

-        if useCUTLASS:
+        if useINTkernel in ["triton", "cutlass"]:
             assert (
                 m1.dtype == torch.int8 and m2.dtype == torch.int8
-            ), "When using cutlass int matmul, inputs must be 2D INT8"
+            ), "When using int matmul, inputs must be 2D and INT8."
             return mm_func(m1, m2).reshape(tar_shape)

         outf32_or_f16 = torch.empty(
@@ -759,7 +762,7 @@ def q_iaddmm_dq(bias, m1, m2, scale_i, zp_i, scale_w):
         assert m2.dtype == torch.int8, f"weight tensor is of incorrect dtype {m2.dtype}"
         m1 = torch.clamp((m1 / scale_i + zp_i - 128).round(), -128, 127).to(torch.int8)

-        if useCUTLASS:
+        if useINTkernel:
             mm_i32 = mm_func(m1, m2)
         else:
             outf32_or_f16 = torch.empty(
@@ -856,6 +859,64 @@ def lower_qmodel_cutlass(
     return mod


+def lower_qmodel_triton(
+    model: torch.nn.Module,
+    use_dyn_max_act=False,
+    max_acc_bits=32,
+    num_lsb_to_truncate=0,
+    chunk_size=32,
+):
+    """
+    Examplar GPU lowering function using triton. Only swap Qlinears in transformers, nothing else.
+    Triton kernel can be used to:
+    1. test INT8 or FP8 HW performance (kernel is not optimized)
+    2. simulate MSB/LSB truncation effect
+
+    Args:
+        model: nn.Module. should be a fms_mo Qmodel, will do inplace layer swapping, no deepcopy
+        use_dyn_max_act: bool or int, can be False or -1 for per-token, or -2 for perCh. will use
+                         dynamic max quantizer for activation if not False.
+        max_acc_bits: max bits for accumulator, typically FP32 for all FP matmuls and INT32 for all
+                      INT matmuls. But some HW could use fewer bits to trade-off power
+                      efficiency at the expense of higher chance of accumulation "overflow".
+                      For example, an INT24 accumulator can only hold values ranged from -2^23 to
+                      2^23 -1, as opposed to typical range -2^31 to -2^31 -1.
+        num_lsb_to_truncate: number of bits to truncate from LSB side. For example, given fp32 is
+                             s1e8m23, if we choose to truncate 13 mantissa bits from right most side,
+                             i.e. LSB, the resulting number will be s1e8m10, which is TF32.
+        chunk_size: given a matmul of (m, k) @ (k, n), the inner product will be "accumulated" along
+                    k-dim. Since the entire matrix will be partitioned into smaller tiles when being
+                    computed, accumulator will only add a certain num of elements in one shot. This
+                    "chunk size" in k-dim will affect the overflow/underflow of accumulator.
+    """
+    # Third Party
+    from torch.ao.quantization.utils import _parent_name
+
+    # Local
+    from fms_mo.modules.linear import QLinear, QLinearINT8Deploy
+
+    for name, m in model.named_modules():
+        if not isinstance(m, QLinear):
+            continue
+        parent_name, module_name = _parent_name(name)
+        parent_mod = model.get_submodule(parent_name)
+        qmod = getattr(parent_mod, module_name)
+        setattr(
+            parent_mod,
+            module_name,
+            QLinearINT8Deploy.from_fms_mo(
+                qmod,
+                use_int_kernel="triton",
+                use_dynamic_max_act_Qfunc=use_dyn_max_act,
+                max_acc_bits=max_acc_bits,
+                truncate_lsb=num_lsb_to_truncate,
+                chunk_size=chunk_size,
+            ),
+        )
+
+    logger.info(f"\nModel lowering with triton kernel is done.\n{model}")
+
+
 ### -------------------------------------------------------------
 # GPTQ tensor packing functions for Exllama kernel
 ### -------------------------------------------------------------
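
For context, a hypothetical usage sketch of the new lowering helper follows. `qmodel` is assumed to be an fms_mo-prepared model whose linear layers are QLinear instances, and the argument values are illustrative only; only the import path and the lower_qmodel_triton signature come from this commit.

from fms_mo.custom_ext_kernels.utils import lower_qmodel_triton

# Swap QLinear layers in place for QLinearINT8Deploy backed by the triton INT8 kernel,
# simulating an INT24 accumulator that rounds away the 4 lowest bits of each chunk.
lower_qmodel_triton(
    qmodel,
    use_dyn_max_act=-1,       # dynamic per-token max quantizer for activations
    max_acc_bits=24,          # clamp per-chunk partial sums to the INT24 range
    num_lsb_to_truncate=4,    # round/truncate 4 LSBs per chunk
    chunk_size=32,            # accumulate 32 elements along k per chunk
)
qmodel.eval()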

fms_mo/dq.py

Lines changed: 6 additions & 5 deletions
@@ -38,7 +38,7 @@
 from fms_mo import qconfig_init, qmodel_prep
 from fms_mo.fx.utils import model_size_Wb
 from fms_mo.quant.ptq import (
-    calibration_llm_1GPU,
+    calibration_llm_1GPU_v2,
     dq_llm,
     get_act_scales,
     get_act_scales_1gpu,
@@ -232,9 +232,9 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     if qcfg["qmodel_calibration_new"] > 0:
         logger.info("Starting to calibrate activation clip_val")
         if qcfg["large_model"]:
-            calibration_llm_1GPU(qcfg, model, dq_dataloader)
+            calibration_llm_1GPU_v2(qcfg, model, dq_dataloader)
         else:
-            model.to("cuda:0")
+            model.to("cuda")
             pbar = tqdm(
                 dq_dataloader,
                 desc=" calibration after applying smoothq scale and before inference",
@@ -263,7 +263,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         test_dataset = load_from_disk(data_args.test_data_path)
         test_dataset = test_dataset.with_format("torch")
     elif len(pt_files) > 0:
-        test_dataset = torch.load(pt_files[0])
+        test_dataset = torch.load(pt_files[0], weights_only=False)

     logger.info(f"Model for evaluation: {model}")
     if qcfg["large_model"]:
@@ -272,7 +272,8 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         model.to(torch.device("cuda:0"))
         n_samples = int(test_dataset.input_ids.shape[1] / block_size)
         evaluator = Evaluator(test_dataset, "cuda", n_samples=n_samples)
-        ppl = evaluator.evaluate(model, block_size=block_size)
+        with patch_torch_bmm(qcfg):
+            ppl = evaluator.evaluate(model, block_size=block_size)
         logger.info(f"Model perplexity: {ppl}")
         logger.info("-" * 50)
         logger.info("Finished evaluation")
