
Commit b685ea8

fix triton DL16 aiu sim with subnorm flushing
Signed-off-by: cliu-us <[email protected]>
1 parent 9925706 commit b685ea8

6 files changed: +229 -64 lines changed

fms_mo/custom_ext_kernels/triton_kernels.py

Lines changed: 64 additions & 44 deletions
@@ -160,13 +160,8 @@ def matmul_kernel(
     # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
     # of fp32 values for higher accuracy.
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-    ## ------ prepare LSB rounding/truncation masks -------
-    # NOTE mask will be applied on accumulator, which is alway FP32, so we may truncate up to 23b
-    # e.g., 20b -> trun_mask = 0xFFF00000, round_bit = 0x00080000
-    #        8b -> trun_mask = 0xFFFFFF00, round_bit = 0x00000080
-    trun_mask = ~tl.cast((1 << chunk_trun_bits) - 1, tl.uint32)
-    round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0
-    ## ---------------------------------------------------------
+    ## ------ prepare LSB rounding/truncation masks outside the for loop -------
+    round_bit, trun_mask = round_and_trun_mask(chunk_trun_bits, clamp_acc_to_dl16)

     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         # Load the next block of A and B, generate a mask by checking the K dimension.
@@ -181,10 +176,10 @@ def matmul_kernel(
         # tl.dot() default is using TF32 approximation, not good enough for LSB truncation exp

         ## ------ add chunky LSB rounding/masking --------
-        if chunk_trun_bits > 0:
-            accumulator_inner = round_and_trun(accumulator_inner, round_bit, trun_mask)
-        if clamp_acc_to_dl16:
-            accumulator_inner = fp32_clamp_to_dl16(accumulator_inner)
+        if clamp_acc_to_dl16 or chunk_trun_bits > 0:
+            accumulator_inner = round_and_trun(
+                accumulator_inner, round_bit, trun_mask, clamp_acc_to_dl16
+            )
         ## ---------------------------------------------------------
         if truncate_then_accumulate:
             accumulator += accumulator_inner
@@ -382,13 +377,8 @@ def matmul_kernel_DABC(
     # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
     # of fp32 values for higher accuracy, i.e. C should have been cast to fp32 already
    accumulator = tl.load(c_ptrs, mask=c_mask, other=0.0)
-    ## ------ prepare LSB rounding/truncation masks -------
-    # NOTE mask will be applied on accumulator, which is alway FP32, so we may truncate up to 23b
-    # e.g., 20b -> trun_mask = 0xFFF00000, round_bit = 0x00080000
-    #        8b -> trun_mask = 0xFFFFFF00, round_bit = 0x00000080
-    trun_mask = ~tl.cast((1 << chunk_trun_bits) - 1, tl.uint32)
-    round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0
-    ## ---------------------------------------------------------
+    ## ------ prepare LSB rounding/truncation masks outside the for loop -------
+    round_bit, trun_mask = round_and_trun_mask(chunk_trun_bits, clamp_acc_to_dl16)

     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         # Load the next block of A, B, and C, generate a mask by checking the K dimension.
@@ -408,10 +398,10 @@ def matmul_kernel_DABC(
         # precision as well, hence, could lose some precision!

         ## ------ add chunky LSB rounding/masking --------
-        if chunk_trun_bits > 0:
-            accumulator_inner = round_and_trun(accumulator_inner, round_bit, trun_mask)
-        if clamp_acc_to_dl16:
-            accumulator_inner = fp32_clamp_to_dl16(accumulator_inner)
+        if clamp_acc_to_dl16 or chunk_trun_bits > 0:
+            accumulator_inner = round_and_trun(
+                accumulator_inner, round_bit, trun_mask, clamp_acc_to_dl16
+            )
         ## ---------------------------------------------------------
         if truncate_then_accumulate:
             accumulator += accumulator_inner
@@ -440,34 +430,64 @@ def leaky_relu(x):


 @triton.jit
-def round_and_trun(x, round_bit, trun_mask):
-    """Round and truncate (usually for accumulator)."""
-    return libdevice.uint_as_float((libdevice.float_as_uint(x) + round_bit) & trun_mask)
+def round_and_trun_mask(chunk_trun_bits, clamp_acc_to_dl16):
+    """
+    Rounding and LSB truncation masks only need to be generated once.
+    These masks will be applied on the "inner" accumulator, which is always FP32 (e8m23). We may
+    truncate up to 23b of mantissa. If DL16/DL8, pay attention to exponent bias.
+    Examples: 20b -> trun_mask = 0xFFF00000, round_bit = 0x00080000
+               8b -> trun_mask = 0xFFFFFF00, round_bit = 0x00000080
+    """
+    if clamp_acc_to_dl16:
+        # DL16 is e6m9, hence, truncate 23 - 9 = 14 bits
+        chunk_trun_bits = 14
+    round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0
+    trun_mask = ~tl.cast((1 << chunk_trun_bits) - 1, tl.uint32)
+    return round_bit, trun_mask


 @triton.jit
-def fp32_clamp_to_dl16(x):
-    """clamp FP32 (1-8-23) TENSOR x to DL16 (1-6-9) range."""
-    # 1. rounding: add round bit, zero out last 13 bits, back to float
-    x = libdevice.float_as_uint(x)
-    round_bit = 1 << (23 - 9 - 1)
-    mask_13x0 = ~tl.cast((1 << 13) - 1, tl.uint32)
-    x = libdevice.uint_as_float((x + round_bit) & mask_13x0)
-
-    # 2. clamp to min/max:
-    # max = 2^32 * 1.(1111 1111 0)_base2 => 2^32*1.(1111 1111 1) will become inf
-    #       (32 + 127) << 23 | (0xFF8 << (23 - 12)) in FP32 is 8581545984.0
-    # min = 2^-31 * 1.(0000 0000 1)_base2 => set to 0 for those smaller than this
-    #       (-31 + 127) << 23 | (1 << (23 - 9)) in FP32 is 4.665707820095122e-10
-    dl16_max = 8581545984.0
-    dl16_min = 4.665707820095122e-10
-    x = tl.where(x >= dl16_max, float("inf"), x)
-    x = tl.where(x <= -dl16_max, float("-inf"), x)
-    x = tl.where(tl.abs(x) < dl16_min, 0, x)
-
+def round_and_trun(x, round_bit, trun_mask, clamp_acc_to_dl16):
+    """Round and truncate (usually for accumulator)."""
+    x = libdevice.uint_as_float((libdevice.float_as_uint(x) + round_bit) & trun_mask)
+
+    if clamp_acc_to_dl16:
+        # clamp to DL16 min/max:
+        # max = 2^32 * 1.(1111 1111 0)_base2 = 2^32 * (2 - 2^-9) = 8581545984.0
+        #       greater than this will become +inf (or -inf)
+        # min = 2^-31 * 1.(0000 0000 1)_base2 = 2^-31 * (1 + 2^-9) = 4.665707820095122e-10
+        #       smaller than this will become 0
+        dl16_max = 8581545984.0
+        dl16_min = 4.665707820095122e-10
+        x = tl.where(x >= dl16_max, float("inf"), x)
+        x = tl.where(x <= -dl16_max, float("-inf"), x)
+        x = tl.where(tl.abs(x) < dl16_min, 0, x)
     return x


+# @triton.jit
+# def fp32_clamp_to_dl16(x):
+#     """clamp FP32 (1-8-23) TENSOR x to DL16 (1-6-9) range."""
+#     # 1. rounding: add round bit, zero out last 13 bits, back to float
+#     x = libdevice.float_as_uint(x)
+#     round_bit = 1 << (23 - 9 - 1)
+#     mask_13x0 = ~tl.cast((1 << 13) - 1, tl.uint32)
+#     x = libdevice.uint_as_float((x + round_bit) & mask_13x0)
+
+#     # 2. clamp to min/max:
+#     # max = 2^32 * 1.(1111 1111 0)_base2 => 2^32*1.(1111 1111 1) will become inf
+#     #       (32 + 127) << 23 | (0xFF8 << (23 - 12)) in FP32 is 8581545984.0
+#     # min = 2^-31 * 1.(0000 0000 1)_base2 => set to 0 for those smaller than this
+#     #       (-31 + 127) << 23 | (1 << (23 - 9)) in FP32 is 4.665707820095122e-10
+#     dl16_max = 8581545984.0
+#     dl16_min = 4.665707820095122e-10
+#     x = tl.where(x >= dl16_max, float("inf"), x)
+#     x = tl.where(x <= -dl16_max, float("-inf"), x)
+#     x = tl.where(tl.abs(x) < dl16_min, 0, x)
+
+#     return x
+
+
 def tl_matmul_chunk_truncate(
     a,
     b,
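
Note: the bit-level scheme above can be sanity-checked outside Triton. The following sketch is not part of the commit; it uses plain NumPy and an illustrative helper name (fp32_round_and_trun) to mirror what round_and_trun_mask plus round_and_trun do to an FP32 tensor, and to re-derive the hard-coded DL16 bounds.

```python
import numpy as np

def fp32_round_and_trun(x, chunk_trun_bits, clamp_acc_to_dl16=False):
    """NumPy emulation of the kernel's LSB rounding/truncation (+ optional DL16 clamp)."""
    if clamp_acc_to_dl16:
        chunk_trun_bits = 14  # FP32 carries 23 mantissa bits, DL16 (e6m9) keeps only 9
    round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0
    trun_mask = np.uint32(~np.uint32((1 << chunk_trun_bits) - 1))

    # reinterpret FP32 as uint32, add the round bit, zero out the truncated LSBs
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    y = ((bits + np.uint32(round_bit)) & trun_mask).view(np.float32)

    if clamp_acc_to_dl16:
        dl16_max = np.float32(8581545984.0)           # 2**32 * (2 - 2**-9)
        dl16_min = np.float32(4.665707820095122e-10)  # 2**-31 * (1 + 2**-9)
        y = np.where(y >= dl16_max, np.float32(np.inf), y)
        y = np.where(y <= -dl16_max, np.float32(-np.inf), y)
        y = np.where(np.abs(y) < dl16_min, np.float32(0.0), y)
    return y

# the docstring examples and the DL16 constants check out:
assert hex(~((1 << 20) - 1) & 0xFFFFFFFF) == "0xfff00000"  # 20b -> trun_mask
assert 2.0**32 * (2 - 2.0**-9) == 8581545984.0             # dl16_max
print(2.0**-31 * (1 + 2.0**-9))                            # ~4.6657e-10, i.e. dl16_min
print(fp32_round_and_trun(np.array([1.0000001, 3.14159], dtype=np.float32), 8))
```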

fms_mo/dq.py

Lines changed: 29 additions & 2 deletions
@@ -161,18 +161,18 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     # config layers to skip, smooth scale
     config_quantize_smooth_layers(qcfg)

+    use_dynamo = True
+    # use dynamo as default unless really needed, False -> fallback to TorchScript tracing
     if any(x != 32 for x in attn_bits):
         logger.info("Quantize attention bmms or kvcache, will use dynamo for prep")
         use_layer_name_pattern_matching = False
         qcfg["qlayer_name_pattern"] = []
         assert (
             qcfg["qlayer_name_pattern"] == []
         ), "ensure nothing in qlayer_name_pattern when use dynamo"
-        use_dynamo = True
     else:
         logger.info("Attention bmms will not be quantized.")
         use_layer_name_pattern_matching = True
-        use_dynamo = False

     qcfg["seq_len"] = block_size
     qcfg["model"] = model_args.model_name_or_path
@@ -271,6 +271,33 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         clamp_acc_to_dl16=False,  # fms_mo_args.aiu_sim_triton == "fp8"
         # layer_to_exclude=["lm_head",]
     )
+    # [CL] -------- record W, A, qW, qA with hooks ----------------
+    # from fms_mo.modules.linear import QLinear, QLinearINT8Deploy
+    # from fms_mo.quant.ptq import HookRecPostQuantInOut
+    # cache_dict = {}
+    # hook_handles = []
+    # for n, m in model.named_modules():
+    #     if not isinstance(m, (QLinear, QLinearINT8Deploy, torch.nn.Linear)):
+    #         continue
+
+    #     m.mod_name = n
+    #     hook_handles.append(
+    #         m.register_forward_hook(HookRecPostQuantInOut(cache_dict, n))
+    #     )
+
+    # data_mb = next(iter(eval_dataloader))
+    # with torch.no_grad():
+    #     model(**data_mb)
+
+    # for h in hook_handles:
+    #     h.remove()
+
+    # torch.save(
+    #     cache_dict,
+    #     f"roberta_sqv2_data_dump_{qcfg['qa_mode']}_{qcfg['qw_mode']}_chunk64_lsb{args.aiu_int_lsb_trun}_dq.pt"
+    # )
+    # return
+
     if fms_mo_args.eval_ppl:
         path_test = Path(data_args.test_data_path)
         arrow_files = list(path_test.glob("*.arrow"))

fms_mo/modules/linear.py

Lines changed: 11 additions & 0 deletions
@@ -1926,6 +1926,7 @@ def forward(
         ctx.chunk_size = chunk_size
         ctx.fp8_dyn = fp8_dyn
         ctx.clamp_acc_to_dl16 = clamp_acc_to_dl16
+        ctx.dl8_min = 0.0087890625

         if fp8_dyn:
             # use Q/dQ simulation for now, meaning still compute in fp16/bf16
@@ -1943,6 +1944,11 @@ def forward(

             x = (x / x_scale).to(torch.float8_e4m3fn).to(org_dtype) * x_scale
             weight = (weight / w_scale).to(torch.float8_e4m3fn).to(org_dtype) * w_scale
+            if clamp_acc_to_dl16:
+                # NOTE For DL8@DL8 acc in DL16, as DL8 doesn't support subnorm numbers like PyTorch
+                # (whose real min for e4m3fn is 2^-9), need to flush subnorm numbers to 0
+                x.masked_fill_(x < ctx.dl8_min, 0)
+                weight.masked_fill_(weight < ctx.dl8_min, 0)

         # triton kernel assumes 2D inputs and cast the return to input.dtype
         output = tl_matmul(
@@ -1983,6 +1989,11 @@ def backward(ctx, grad_output):
             grad_output_2D = (grad_output_2D / grad_out_scale).to(torch.float8_e5m2).to(
                 grad_output.dtype
             ) * grad_out_scale
+            if ctx.clamp_acc_to_dl16:
+                # flush subnorm numbers to 0 as DL8 doesn't support it
+                x.masked_fill_(x < ctx.dl8_min, 0)
+                weight.masked_fill_(weight < ctx.dl8_min, 0)
+                grad_output_2D.masked_fill_(grad_output_2D < ctx.dl8_min, 0)

         # Compute grad_weight, shape = [out, in]
         # NOTE: this triton kernel requires A matrix to be contiguous
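
Note: the forward/backward additions above flush values below ctx.dl8_min to zero after the fake cast to torch.float8_e4m3fn, because PyTorch's e4m3fn keeps subnormals (down to 2^-9) while the simulated DL8 format does not. Below is a minimal standalone sketch of the same idea, not taken from the commit: the helper name is illustrative, the threshold reuses the commit's dl8_min constant, and the flush here is applied by magnitude.

```python
import torch

DL8_MIN = 0.0087890625  # smallest DL8 magnitude kept, per the commit's ctx.dl8_min

def flush_subnormals(t: torch.Tensor, min_normal: float = DL8_MIN) -> torch.Tensor:
    """Zero out entries whose magnitude falls below the format's minimum normal value."""
    return t.masked_fill(t.abs() < min_normal, 0.0)

# fake-quantize to float8_e4m3fn (PyTorch keeps subnormals down to 2**-9),
# then drop anything a subnormal-free DL8 could not represent
x = torch.tensor([0.5, 0.004, -0.002, 0.02], dtype=torch.bfloat16)
x_fp8 = x.to(torch.float8_e4m3fn).to(torch.bfloat16)
print(flush_subnormals(x_fp8))  # the small-magnitude entries become 0
```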

fms_mo/quant/ptq.py

Lines changed: 112 additions & 1 deletion
@@ -42,6 +42,7 @@
 # Local
 from fms_mo.modules import QBmm, QLinear
 from fms_mo.modules.conv import QConv2dPTQv2
+from fms_mo.modules.linear import LinearFPxAcc, QLinearINT8Deploy
 from fms_mo.quant.quantizers import (
     AdaRoundQuantizer,
     Qdynamic,
@@ -481,8 +482,118 @@ def __call__(self, mod, inputs, *args, **_kwargs):
         assert not self.stop_after_rec


-# this hook is meant for ptq_loss_func == 'fisher_diag' and to temp hold the "Q_out" of the module
+class HookRecPostQuantInOut(torch.nn.Module):
+    """Another simplified hook to check post-quantized input/output, e.g. within +-127 for INT8."""
+
+    def __init__(self, cache_dict={}, mod_name=None):
+        super().__init__()
+        self.cache_dict = cache_dict
+        self.mod_name = mod_name
+        name_split = mod_name.split(".")
+        self.lay_idx = int(name_split[3])
+        self.lay_key = name_split[6]
+
+        self.cache_dev = "cpu"
+        # prepare empty dict for later use
+        self.cache_dict[mod_name] = {}
+        self.fwd_mapping = {
+            LinearFPxAcc: self.call_func_for_fpxacc,
+            QLinear: self.call_func_for_qlinear,
+            QLinearINT8Deploy: self.call_func_for_qlinear_int,
+            torch.nn.Linear: self.call_func_for_nnlinear,
+        }
+
+    def call_func_for_fpxacc(self, mod, inputs, outputs, **_kwargs):
+        raise NotImplementedError
+
+    def call_func_for_qlinear(self, mod, inputs, outputs, **_kwargs):
+        lay_idx = self.lay_idx
+        lay_key = self.lay_key
+        mod_name = self.mod_name
+        cache_dict = self.cache_dict
+
+        act_max = inputs[0].abs().amax(dim=[d for d in range(len(inputs[0].shape) - 1)])
+        # mod.smoothq_act_scale
+        w_max = mod.weight.abs().max(dim=0, keepdim=True)[0].clamp(min=1e-5)
+        is_smq_layer = not torch.all(act_max == 0).item()
+        # smoothQ scale = smoothq_act_scale**alpha / weight_scale**(1.0 - alpha)
+        # smoothq_scale = mod.get_smoothq_scale(inputs[0])
+        smoothq_scale = getattr(mod, "smq_scale", 1.0)
+        # "smq_scale" only available in QLin_INT8
+
+        with torch.no_grad():
+            smoothed_inp = inputs[0] / smoothq_scale
+            smoothed_w = mod.weight * smoothq_scale
+
+            # this is assuming pertokenmax quantizer, NOTE calc quant scale after smoothing
+            absmax = smoothed_inp.abs().max(dim=-1, keepdim=True)[0]
+            qa_scale = absmax.clamp(min=1e-5) / 127
+            qinput = torch.round(smoothed_inp / qa_scale).clamp(-127, 127)
+            # should clamp to -128?
+            if mod.qa_mode == "pertokenmax":
+                # doesn't implement dequant=False yet, do it manually
+                cva = mod.quantize_feature.clip_val
+                qa_scale = cva.clamp(min=1e-5).div(127)
+                qinput = smoothed_inp.div(qa_scale).round()
+            else:
+                mod.quantize_feature.dequantize = False
+                qinput = mod.quantize_feature(smoothed_inp)
+                mod.quantize_feature.dequantize = True
+
+            # also record quantized, smoothed W in INT8, calc both maxperCh and SAWBperCh
+            cvw = mod.quantize_weight.clip_val
+            scale_w = cvw / 127
+            mod.quantize_weight.dequantize = False
+            qw = mod.quantize_weight(smoothed_w)
+            mod.quantize_weight.dequantize = True
+
+        # inputs is a tuple, QLinear only has 1 valid input
+        cache_dict[mod_name]["input"] = inputs[0].to(self.cache_dev)
+        cache_dict[mod_name]["cva"] = cva.to(self.cache_dev)
+        cache_dict[mod_name]["cvw"] = cvw.to(self.cache_dev)
+        cache_dict[mod_name]["smoothed_input"] = smoothed_inp.to(self.cache_dev)
+        cache_dict[mod_name]["smoothed_weight"] = smoothed_w.to(self.cache_dev)
+        cache_dict[mod_name]["qinput"] = qinput.to(self.cache_dev)
+        # NOTE in INT8, *scales if need dQ
+        cache_dict[mod_name]["qweight"] = qw.to(self.cache_dev)
+        # torch.round(smoothed_w.T/scale_w).clamp(-127, 127).to(self.cache_dev)
+        # cache_dict[mod_name]["qoutput"] = outputs.to(self.cache_dev)
+
+    def call_func_for_qlinear_int(self, mod, inputs, outputs, **_kwargs):
+        smoothq_scale = getattr(mod, "smq_scale", 1.0)
+        mod_name = self.mod_name
+        cache_dict = self.cache_dict
+        with torch.no_grad():
+            if mod.useDynMaxQfunc in [-1, -2]:
+                qinput = mod.qa_dynamic_max_qfunc(inputs[0])
+            elif mod.use_fake_zero_shift:
+                qinput = mod.qa_dyn_max_fake_zero_shift(inputs[0])
+            elif mod.usePTnativeQfunc:
+                qinput = mod.qa_raw_qfunc(inputs[0])
+            else:
+                qinput = mod.qa_fmo_mo_qfunc(inputs[0])
+
+        # inputs is a tuple, QLinear only has 1 valid input
+        cache_dict[mod_name]["input"] = inputs[0].to(self.cache_dev)
+        cache_dict[mod_name]["cva"] = mod.cvs[0].to(self.cache_dev)
+        cache_dict[mod_name]["cvw"] = mod.cvs[2].to(self.cache_dev)
+        cache_dict[mod_name]["qinput"] = qinput.to(self.cache_dev)
+        cache_dict[mod_name]["qweight"] = mod.weight.to(self.cache_dev)
+
+    def call_func_for_nnlinear(self, mod, inputs, outputs, **_kwargs):
+        mod_name = self.mod_name
+        cache_dict = self.cache_dict
+        cache_dict[mod_name]["input"] = inputs[0].to(self.cache_dev)
+        cache_dict[mod_name]["weight"] = mod.weight.to(self.cache_dev)
+
+    def __call__(self, mod, inputs, outputs, **_kwargs):
+        self.fwd_mapping[type(mod)](mod, inputs, outputs, **_kwargs)
+
+
 class PTQHookRecQOut(nn.Module):
+    """This hook is for ptq_loss_func == 'fisher_diag' and will temporarily hold the "Q_out" of the
+    module"""
+
     def __init__(self, qcfg):
         super().__init__()
         self.qcfg = qcfg
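
Note: the commented-out block added to fms_mo/dq.py above shows how this hook is intended to be wired up. A condensed usage sketch follows; the function name, dataloader, and save path are placeholders, and the hook parses module names positionally (name_split[3], name_split[6]), so it expects transformer-style layer paths like those in dq.py.

```python
import torch
from fms_mo.modules.linear import QLinear, QLinearINT8Deploy
from fms_mo.quant.ptq import HookRecPostQuantInOut

def record_post_quant_io(model, dataloader, save_path="post_quant_dump.pt"):
    """Attach HookRecPostQuantInOut to each (Q)Linear, run one batch, and save the captures."""
    cache_dict, handles = {}, []
    for name, module in model.named_modules():
        if not isinstance(module, (QLinear, QLinearINT8Deploy, torch.nn.Linear)):
            continue
        module.mod_name = name
        handles.append(module.register_forward_hook(HookRecPostQuantInOut(cache_dict, name)))

    batch = next(iter(dataloader))
    with torch.no_grad():
        model(**batch)  # hooks fill cache_dict during this forward pass

    for h in handles:
        h.remove()
    torch.save(cache_dict, save_path)
    return cache_dict
```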

fms_mo/training_args.py

Lines changed: 1 addition & 1 deletion
@@ -192,7 +192,7 @@ class FMSMOArguments(TypeChecker):
         default=2048, metadata={"help": "input sequence length after tokenization"}
     )
     eval_ppl: bool = field(default=False)
-    aiu_sim_triton: str = field(
+    aiu_sim_triton: Optional[str] = field(
         default=None,
         metadata={
             "help": (

0 commit comments
