
Commit 0e98567

Merge pull request #159 from chichun-charlie-liu/triton_aiu_sim
feat: AIU sim for FP8 (DL8/DL16) added to triton kernel
2 parents 952b6d4 + 1bbf139 commit 0e98567

9 files changed: +357 additions, -108 deletions


fms_mo/custom_ext_kernels/triton_kernels.py

Lines changed: 52 additions & 22 deletions
@@ -114,6 +114,7 @@ def matmul_kernel(
     stride_cn,
     chunk_trun_bits,
     max_acc_bits,  # pylint: disable=unused-argument
+    clamp_acc_to_dl16,
     truncate_then_accumulate,
     # Meta-parameters
     BLOCK_SIZE_M: tl.constexpr,
@@ -159,13 +160,8 @@ def matmul_kernel(
     # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
     # of fp32 values for higher accuracy.
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-    ## ------ prepare LSB rounding/truncation masks -------
-    # NOTE mask will be applied on accumulator, which is alway FP32, so we may truncate up to 23b
-    # e.g., 20b -> trun_mask = 0xFFF00000, round_bit = 0x00080000
-    #       8b -> trun_mask = 0xFFFFFF00, round_bit = 0x00000080
-    trun_mask = tl.cast((0xFFFFFFFF >> chunk_trun_bits) << chunk_trun_bits, tl.uint32)
-    round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0
-    ## ---------------------------------------------------------
+    ## ------ prepare LSB rounding/truncation masks outside the for loop -------
+    round_bit, trun_mask = round_and_trun_mask(chunk_trun_bits, clamp_acc_to_dl16)

     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         # Load the next block of A and B, generate a mask by checking the K dimension.
@@ -180,8 +176,10 @@ def matmul_kernel(
         # tl.dot() default is using TF32 approximation, not good enough for LSB truncation exp

         ## ------ add chunky LSB rounding/masking --------
-        if chunk_trun_bits > 0:
-            accumulator_inner = round_and_trun(accumulator_inner, round_bit, trun_mask)
+        if clamp_acc_to_dl16 or chunk_trun_bits > 0:
+            accumulator_inner = round_and_trun(
+                accumulator_inner, round_bit, trun_mask, clamp_acc_to_dl16
+            )
         ## ---------------------------------------------------------
         if truncate_then_accumulate:
             accumulator += accumulator_inner
@@ -226,6 +224,7 @@ def imatmul_kernel(
     stride_cn,
     chunk_trun_bits,
     max_acc_bits,
+    clamp_acc_to_dl16,  # pylint: disable=unused-argument
     truncate_then_accumulate,
     # Meta-parameters
     BLOCK_SIZE_M: tl.constexpr,
@@ -324,6 +323,7 @@ def matmul_kernel_DABC(
     stride_cn,
     chunk_trun_bits,
     max_acc_bits,  # pylint: disable=unused-argument
+    clamp_acc_to_dl16,
     truncate_then_accumulate,
     # Meta-parameters
     BLOCK_SIZE_M: tl.constexpr,
@@ -377,13 +377,8 @@ def matmul_kernel_DABC(
     # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
     # of fp32 values for higher accuracy, i.e. C should have been cast to fp32 already
     accumulator = tl.load(c_ptrs, mask=c_mask, other=0.0)
-    ## ------ prepare LSB rounding/truncation masks -------
-    # NOTE mask will be applied on accumulator, which is alway FP32, so we may truncate up to 23b
-    # e.g., 20b -> trun_mask = 0xFFF00000, round_bit = 0x00080000
-    #       8b -> trun_mask = 0xFFFFFF00, round_bit = 0x00000080
-    trun_mask = tl.cast((0xFFFFFFFF >> chunk_trun_bits) << chunk_trun_bits, tl.uint32)
-    round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0
-    ## ---------------------------------------------------------
+    ## ------ prepare LSB rounding/truncation masks outside the for loop -------
+    round_bit, trun_mask = round_and_trun_mask(chunk_trun_bits, clamp_acc_to_dl16)

     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         # Load the next block of A, B, and C, generate a mask by checking the K dimension.
@@ -403,8 +398,10 @@ def matmul_kernel_DABC(
         # precision as well, hence, could lose some precision!

         ## ------ add chunky LSB rounding/masking --------
-        if chunk_trun_bits > 0:
-            accumulator_inner = round_and_trun(accumulator_inner, round_bit, trun_mask)
+        if clamp_acc_to_dl16 or chunk_trun_bits > 0:
+            accumulator_inner = round_and_trun(
+                accumulator_inner, round_bit, trun_mask, clamp_acc_to_dl16
+            )
         ## ---------------------------------------------------------
         if truncate_then_accumulate:
             accumulator += accumulator_inner
@@ -433,9 +430,39 @@ def leaky_relu(x):


 @triton.jit
-def round_and_trun(x, round_bit, trun_mask):
+def round_and_trun_mask(chunk_trun_bits, clamp_acc_to_dl16):
+    """
+    Rounding and LSB truncation masks only need to be generated once.
+    These mask will be applied on "inner" accumulator, which is alway FP32 (e8m23). We may truncate
+    up to 23b for mantissa. If DL16/DL8, pay attention to exponent bias.
+    Examples: 20b -> trun_mask = 0xFFF00000, round_bit = 0x00080000
+               8b -> trun_mask = 0xFFFFFF00, round_bit = 0x00000080
+    """
+    if clamp_acc_to_dl16:
+        # DL16 is e6m9, hence, truncate 23 - 9 = 14 bits
+        chunk_trun_bits = 14
+    round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0
+    trun_mask = ~tl.cast((1 << chunk_trun_bits) - 1, tl.uint32)
+    return round_bit, trun_mask
+
+
+@triton.jit
+def round_and_trun(x, round_bit, trun_mask, clamp_acc_to_dl16):
     """Round and truncate (usually for accumulator)."""
-    return libdevice.uint_as_float((libdevice.float_as_uint(x) + round_bit) & trun_mask)
+    x = libdevice.uint_as_float((libdevice.float_as_uint(x) + round_bit) & trun_mask)
+
+    if clamp_acc_to_dl16:
+        # clamp to DL16 min/max:
+        # max = 2^32 * 1.(1111 1111 0)_base2 = 2^32 * (2 - 2^-9) = 8581545984.0
+        #       greater than this will become +inf (or -inf)
+        # min = 2^-31 * 1.(0000 0000 1)_base2 = 2^-31 * (1 + 2^-9) = 4.665707820095122e-10
+        #       smaller than this will become 0
+        dl16_max = 8581545984.0
+        dl16_min = 4.665707820095122e-10
+        x = tl.where(x >= dl16_max, float("inf"), x)
+        x = tl.where(x <= -dl16_max, float("-inf"), x)
+        x = tl.where(tl.abs(x) < dl16_min, 0, x)
+    return x


 def tl_matmul_chunk_truncate(
@@ -448,6 +475,7 @@ def tl_matmul_chunk_truncate(
     max_acc_bits=32,
     truncate_then_accumulate=True,
     cast_output_to_input_dtype=None,
+    clamp_acc_to_dl16=False,
 ):
     """Triton matmul for HW behavior simulation. Supports float and int8.
         i. variable chunk size (i.e., BLOCK_SIZE_K)
@@ -461,7 +489,8 @@ def tl_matmul_chunk_truncate(
         chunk_size (int, optional): BLOCK_SIZE_K, some HW has specific chunk size. must >= 16.
         max_acc_bits (int, optional): num of bits for the accumulator, e.g. if INT24 is used, will
                                 clamp each chunk of a*b to [-2**23-1, 2**23].
-                                (assuming no inf when overflow)
+                                (only used by INT)
+        clamp_acc_to_dl16(bool): Only used by FP8, whether to clamp local accumulator (FP32) to DL16
         truncate_then_accumulate (bool, optional): if True, c = truncate(a*b) + c, otherwise
                                 c = truncate(a*b+c)
         cast_output_to_input_dtype (bool, optional): accumulator has higher prec than input, usually
@@ -473,7 +502,7 @@ def tl_matmul_chunk_truncate(

     NOTE:
     use empirical way to determine BLOCK sizes, may not be optimal. But need to avoid autotune for
-    real model inference. otherwise auto-tune will be triggered in every forward call.
+    real model inference. otherwise auto-tune may be triggered in every forward call.
     """

     # Check constraints.
@@ -584,6 +613,7 @@ def grid(META):
         c.stride(1),
         chunk_trun_bits=chunk_trun_bits,
         max_acc_bits=max_acc_bits,
+        clamp_acc_to_dl16=clamp_acc_to_dl16,
         truncate_then_accumulate=truncate_then_accumulate,
         ACTIVATION=activation,
         **kernel_config,  # if using auto-tune, comment this line out.
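
For a quick sanity check of the new DL16 path, the bit-level round/truncate and the DL16 clamp bounds used by round_and_trun_mask / round_and_trun can be reproduced outside Triton. Below is a minimal NumPy sketch; the reference helper name is hypothetical and not part of this commit.

```python
import numpy as np


def round_and_trun_ref(x, chunk_trun_bits, clamp_acc_to_dl16):
    """NumPy stand-in for the kernel's LSB round/truncate + DL16 clamp (sanity checks only)."""
    if clamp_acc_to_dl16:
        chunk_trun_bits = 14  # DL16 is e6m9 -> drop 23 - 9 = 14 FP32 mantissa bits
    round_bit = np.uint32(1 << (chunk_trun_bits - 1)) if chunk_trun_bits > 0 else np.uint32(0)
    trun_mask = np.uint32(~((1 << chunk_trun_bits) - 1) & 0xFFFFFFFF)

    # reinterpret FP32 values as uint32 bits, round-to-nearest on the kept LSB, then mask
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    y = ((bits + round_bit) & trun_mask).view(np.float32)

    if clamp_acc_to_dl16:
        dl16_max = 8581545984.0           # 2^32 * (2 - 2^-9)
        dl16_min = 4.665707820095122e-10  # 2^-31 * (1 + 2^-9)
        y = np.where(y >= dl16_max, np.float32(np.inf), y)
        y = np.where(y <= -dl16_max, np.float32(-np.inf), y)
        y = np.where(np.abs(y) < dl16_min, np.float32(0.0), y)
    return y


# chunk_trun_bits=8 reproduces the docstring example: trun_mask=0xFFFFFF00, round_bit=0x00000080
print(hex(~((1 << 8) - 1) & 0xFFFFFFFF), hex(1 << 7))
# large values clamp to inf, tiny values flush to 0 under the DL16 range
print(round_and_trun_ref([1.00001, 3.14159, 1e12, 1e-12], chunk_trun_bits=0, clamp_acc_to_dl16=True))
```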

fms_mo/custom_ext_kernels/utils.py

Lines changed: 46 additions & 11 deletions
@@ -870,14 +870,16 @@ def lower_qmodel_triton(
     model: torch.nn.Module,
     use_dyn_max_act=False,
     max_acc_bits=32,
+    clamp_acc_to_dl16=False,
     num_lsb_to_truncate=0,
     chunk_size=32,
+    layer_to_exclude=None,
 ):
     """
-    Examplar GPU lowering function using triton. Only swap Qlinears in transformers, nothing else.
+    Examplar GPU lowering function using triton. Only swap Linear/Qlinear in transformers.
     Triton kernel can be used to:
     1. test INT8 or FP8 HW performance (kernel is not optimized)
-    2. simulate MSB/LSB truncation effect
+    2. simulate MSB/LSB truncation effect or special dtype (DL16) accumulation

     Args:
         model: nn.Module. should be a fms_mo Qmodel, will do inplace layer swapping, no deepcopy
@@ -888,6 +890,8 @@ def lower_qmodel_triton(
                         efficiency at the expense of higher chance of accumulation "overflow".
                         For example, an INT24 accumulator can only hold values ranged from -2^23 to
                         2^23 -1, as opposed to typical range -2^31 to -2^31 -1.
+        clamp_acc_to_dl16: clamp local accumulator to DL16 (1-6-9) range. To simulate this special
+                        dtype effect on accumulation.
         num_lsb_to_truncate: number of bits to truncate from LSB side. For example, given fp32 is
                         s1e8m23, if we choose to truncate 13 mantissa bits from right most side,
                         i.e. LSB, the resulting number will be s1e8m10, which is TF32.
@@ -900,25 +904,56 @@ def lower_qmodel_triton(
    from torch.ao.quantization.utils import _parent_name

     # Local
-    from fms_mo.modules.linear import QLinear, QLinearINT8Deploy
+    from fms_mo.modules.linear import LinearFPxAcc, QLinear, QLinearINT8Deploy
+
+    # Currently QLinearINT8 has more options in dynamic quantization than LinearFP. Here we resolve
+    # the differences as a patch solution (will unify the codes in future release)
+    linFP_dyn_code = (
+        "per_token"
+        if use_dyn_max_act in [-1, -2]
+        else "per_tensor"
+        if use_dyn_max_act
+        else False
+    )
+
+    if layer_to_exclude is None:
+        layer_to_exclude = []
+    elif isinstance(layer_to_exclude, str):
+        layer_to_exclude = [
+            layer_to_exclude,
+        ]
+    elif not isinstance(layer_to_exclude, (list, tuple)):
+        raise RuntimeError("layer_to_exclude has to be either str, list, or tuple.")

     for name, m in model.named_modules():
-        if not isinstance(m, QLinear):
+        if not isinstance(m, (QLinear, torch.nn.Linear)) or name in layer_to_exclude:
             continue
         parent_name, module_name = _parent_name(name)
         parent_mod = model.get_submodule(parent_name)
-        qmod = getattr(parent_mod, module_name)
-        setattr(
-            parent_mod,
-            module_name,
-            QLinearINT8Deploy.from_fms_mo(
-                qmod,
+
+        # Only support simulations of 1) QLinear -> INT8, 2) nnLinear->FP8 for now
+        if isinstance(m, QLinear):
+            new_lin = QLinearINT8Deploy.from_fms_mo(
+                m,
                 use_int_kernel="triton",
                 use_dynamic_max_act_Qfunc=use_dyn_max_act,
                 max_acc_bits=max_acc_bits,
                 truncate_lsb=num_lsb_to_truncate,
                 chunk_size=chunk_size,
-            ),
+            )
+        else:
+            new_lin = LinearFPxAcc.from_nn(
+                m,
+                trun_bits=num_lsb_to_truncate,
+                chunk_size=chunk_size,
+                dynamic_fp8=linFP_dyn_code,
+                clamp_acc_to_dl16=clamp_acc_to_dl16,
+            )
+
+        setattr(
+            parent_mod,
+            module_name,
+            new_lin,
        )

     logger.info(f"\nModel lowering with triton kernel is done.\n{model}")
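
Given the updated signature, an FP8/DL16 simulation run might be wired up as below. This is an illustrative sketch: the argument values are placeholders rather than recommended HW settings, and `model` is assumed to be an already-loaded fms_mo/transformers model.

```python
from fms_mo.custom_ext_kernels.utils import lower_qmodel_triton

# Swap Linear/QLinear layers for triton-backed simulation layers, in place.
lower_qmodel_triton(
    model,                         # assumed: fms_mo Qmodel (INT8 path) or plain nn.Linear model (FP8 path)
    use_dyn_max_act=-1,            # -1/-2 map to "per_token" dynamic quantization for the FP8 path
    max_acc_bits=32,               # only used by the INT path
    clamp_acc_to_dl16=True,        # FP8 path: clamp the local FP32 accumulator to DL16 (1-6-9)
    num_lsb_to_truncate=0,         # DL16 clamping already truncates 14 mantissa bits per chunk
    chunk_size=32,                 # BLOCK_SIZE_K, must be >= 16
    layer_to_exclude=["lm_head"],  # e.g. keep the output head untouched
)
```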

fms_mo/dq.py

Lines changed: 18 additions & 14 deletions
@@ -161,18 +161,18 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     # config layers to skip, smooth scale
     config_quantize_smooth_layers(qcfg)

+    use_dynamo = True
+    # use dynamo as default unless really needed, False -> fallback to TorchScript tracing
     if any(x != 32 for x in attn_bits):
         logger.info("Quantize attention bmms or kvcache, will use dynamo for prep")
         use_layer_name_pattern_matching = False
         qcfg["qlayer_name_pattern"] = []
         assert (
             qcfg["qlayer_name_pattern"] == []
         ), "ensure nothing in qlayer_name_pattern when use dynamo"
-        use_dynamo = True
     else:
         logger.info("Attention bmms will not be quantized.")
         use_layer_name_pattern_matching = True
-        use_dynamo = False

     qcfg["seq_len"] = block_size
     qcfg["model"] = model_args.model_name_or_path
@@ -216,17 +216,18 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         act_scales = get_act_scales(model, dq_dataloader, qcfg)
         torch.save(act_scales, scale_file)

-    qmodel_prep(
-        model,
-        dq_dataloader,
-        qcfg,
-        use_layer_name_pattern_matching=use_layer_name_pattern_matching,
-        use_dynamo=use_dynamo,
-        dev=dev,
-        save_fname="dq",
-    )
-    logger.info(f"Quantized model {model}")
-    logger.info("==" * 20)
+    if fms_mo_args.aiu_sim_triton != "fp8":
+        qmodel_prep(
+            model,
+            dq_dataloader,
+            qcfg,
+            use_layer_name_pattern_matching=use_layer_name_pattern_matching,
+            use_dynamo=use_dynamo,
+            dev=dev,
+            save_fname="dq",
+        )
+        logger.info(f"Quantized model {model}")
+        logger.info("==" * 20)

     if qcfg["smoothq"]:
         logger.info("Starting to apply smooth scale")
@@ -260,12 +261,15 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         tokenizer.save_pretrained(opt_args.output_dir)

     if fms_mo_args.aiu_sim_triton:
+        # NOTE plz apply correct HW settings here, defaults are not real HW params
         lower_qmodel_triton(
             model,
             use_dyn_max_act=-1 if qcfg["qa_mode"] == "pertokenmax" else False,
             max_acc_bits=qcfg.get("max_acc_bits", 32),
             num_lsb_to_truncate=qcfg.get("lsb_trun_bits", 0),
-            chunk_size=qcfg.get("chunk_size", 1024),
+            chunk_size=qcfg.get("chunk_size", 32),  # 1024
+            clamp_acc_to_dl16=fms_mo_args.aiu_sim_triton == "fp8",
+            # layer_to_exclude=["lm_head",]
         )

     if fms_mo_args.eval_ppl:

fms_mo/fx/dynamo_utils.py

Lines changed: 9 additions & 3 deletions
@@ -1180,14 +1180,20 @@ def cus_backend_model_analyzer(
     if is_transformers:
         # NOTE simplified method to determine 1st/last modules for transformers.
         # will not work if model has multiple parallel heads at the end, e.g. obj det
-        def call_seq_hook(mod, *_args, **_kwargs):
-            qcfg["mod_call_seq"].append(lut_weight2modname[mod.weight])
+        def call_seq_hook(mod, *_args, **kwargs):
+            mod_name = kwargs.get("mod_name", lut_weight2modname.get(mod.weight, None))
+            if mod_name is None:
+                raise RuntimeError("cannot determine module name, plz check model.")
+
+            qcfg["mod_call_seq"].append(mod_name)

         h_hooks = []
         qcfg["mod_call_seq"] = []
         for n, m in model.named_modules():
             if isinstance(m, (torch.nn.Linear, torch.nn.Conv2d)):
-                h_hooks.append(m.register_forward_hook(call_seq_hook))
+                h_hooks.append(
+                    m.register_forward_hook(partial(call_seq_hook, mod_name=n))
+                )

         with torch.no_grad():
             run_fwd_once(model, sample_inp)
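
The hook change above switches from a weight-tensor lookup to passing the module name via functools.partial. A self-contained sketch of that pattern on a toy model (not the fms_mo code) looks like this:

```python
from functools import partial

import torch

call_seq = []


def call_seq_hook(mod, *_args, mod_name=None, **_kwargs):
    # record the module's name when its forward() runs
    call_seq.append(mod_name)


model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
hooks = [
    m.register_forward_hook(partial(call_seq_hook, mod_name=n))
    for n, m in model.named_modules()
    if isinstance(m, torch.nn.Linear)
]

with torch.no_grad():
    model(torch.randn(1, 4))

for h in hooks:
    h.remove()

print(call_seq)  # e.g. ['0', '2'] -> first and last Linear in call order
```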

fms_mo/fx/utils.py

Lines changed: 2 additions & 2 deletions
@@ -461,14 +461,14 @@ def model_size_Wb(mod, unit="MB", print_to_file=True, show_details=False):
                 w_mat.numel() * w_mat.element_size()
                 + b_mat.numel() * b_mat.element_size()
             )
-            w_dtype = w_mat.dtype
+            w_dtype = str(w_mat.dtype)
             w_shape = w_mat.shape

         elif isinstance(w, torch.Tensor):
             mem_use = w.numel() * w.element_size()
             if hasattr(m, "bias") and m.bias is not None:
                 mem_use += m.bias.numel() * m.bias.element_size()
-            w_dtype = w.dtype
+            w_dtype = str(w.dtype)
             w_shape = w.shape

         if w_shape:
