
Commit 553c7a6

1. Temporarily enable Qmax.dequant=False (bmgroth will officially enable it later); 2. add a util function to lower a Qmodel to the triton kernel; 3. additional fixes for dq, e.g. torch.load
Signed-off-by: cliu-us <[email protected]>
1 parent 362d521 commit 553c7a6

5 files changed: +103 −25 lines


fms_mo/custom_ext_kernels/utils.py

Lines changed: 58 additions & 0 deletions

@@ -859,6 +859,64 @@ def lower_qmodel_cutlass(
     return mod
 
 
+def lower_qmodel_triton(
+    model: torch.nn.Module,
+    use_dyn_max_act=False,
+    max_acc_bits=32,
+    num_lsb_to_truncate=0,
+    chunk_size=32,
+):
+    """
+    Exemplar GPU lowering function using triton. Only swaps QLinears in transformers, nothing else.
+    The triton kernel can be used to:
+    1. test INT8 or FP8 HW performance (kernel is not optimized)
+    2. simulate MSB/LSB truncation effects
+
+    Args:
+        model: nn.Module. Should be an fms_mo Qmodel; layers are swapped in place, no deepcopy.
+        use_dyn_max_act: bool or int, can be False, -1 for per-token, or -2 for per-channel. Will use a
+            dynamic max quantizer for activations if not False.
+        max_acc_bits: max bits for the accumulator, typically FP32 for all FP matmuls and INT32 for all
+            INT matmuls. But some HW could use fewer bits to trade off power
+            efficiency at the expense of a higher chance of accumulation "overflow".
+            For example, an INT24 accumulator can only hold values ranging from -2^23 to
+            2^23 - 1, as opposed to the typical range -2^31 to 2^31 - 1.
+        num_lsb_to_truncate: number of bits to truncate from the LSB side. For example, given FP32 is
+            s1e8m23, if we choose to truncate 13 mantissa bits from the right-most side,
+            i.e. the LSBs, the resulting number will be s1e8m10, which is TF32.
+        chunk_size: given a matmul of (m, k) @ (k, n), the inner product will be "accumulated" along the
+            k-dim. Since the entire matrix will be partitioned into smaller tiles when being
+            computed, the accumulator will only add a certain number of elements in one shot. This
+            "chunk size" in the k-dim will affect the overflow/underflow of the accumulator.
+    """
+    # Third Party
+    from torch.ao.quantization.utils import _parent_name
+
+    # Local
+    from fms_mo.modules.linear import QLinear, QLinearINT8Deploy
+
+    for name, m in model.named_modules():
+        if not isinstance(m, QLinear):
+            continue
+        parent_name, module_name = _parent_name(name)
+        parent_mod = model.get_submodule(parent_name)
+        qmod = getattr(parent_mod, module_name)
+        setattr(
+            parent_mod,
+            module_name,
+            QLinearINT8Deploy.from_fms_mo(
+                qmod,
+                use_int_kernel="triton",
+                use_dynamic_max_act_Qfunc=use_dyn_max_act,
+                max_acc_bits=max_acc_bits,
+                truncate_lsb=num_lsb_to_truncate,
+                chunk_size=chunk_size,
+            ),
+        )
+
+    logger.info(f"\nModel lowering with triton kernel is done.\n{model}")
+
+
 ### -------------------------------------------------------------
 # GPTQ tensor packing functions for Exllama kernel
 ### -------------------------------------------------------------
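As a rough usage sketch (not part of the commit): the helper swaps every QLinear in place, so a typical call on an already-quantized model might look like the following. The qmodel variable is a placeholder, and the keyword values simply mirror the docstring above.

# Hedged sketch: lower an existing fms_mo Qmodel to the triton INT8 kernel and emulate
# a narrower accumulator plus TF32-style LSB truncation. `qmodel` is a placeholder for a
# model that already contains QLinear layers (e.g. one produced by run_dq).
from fms_mo.custom_ext_kernels.utils import lower_qmodel_triton

qmodel = ...  # placeholder: an fms_mo Qmodel with QLinear modules

lower_qmodel_triton(
    qmodel,                  # swapped in place, no deepcopy
    use_dyn_max_act=-1,      # -1 = per-token dynamic max quantizer for activations
    max_acc_bits=24,         # emulate an INT24 accumulator instead of INT32
    num_lsb_to_truncate=13,  # FP32 (s1e8m23) -> TF32-like (s1e8m10)
    chunk_size=32,           # accumulate 32 k-dim elements per chunk
)
# qmodel can then be run/evaluated as usual; the swapped QLinearINT8Deploy layers use triton.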

fms_mo/dq.py

Lines changed: 8 additions & 6 deletions

@@ -172,7 +172,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
 
     qcfg["seq_len"] = block_size
     qcfg["model"] = model_args.model_name_or_path
-    qcfg["smoothq"] = True
+    qcfg["smoothq"] = fms_mo_args.smoothq_alpha != -1
     qcfg["plotsvg"] = False
 
     calibration_dataset = load_from_disk(data_args.training_data_path)
@@ -217,9 +217,10 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         save_fname="dq",
     )
     logger.info(f"Quantized model {model}")
-    logger.info("Starting to apply smooth scale")
-    dq_llm(model, act_scales, qcfg)
-    logger.info("Finished applying smooth scale")
+    if qcfg["smoothq"]:
+        logger.info("Starting to apply smooth scale")
+        dq_llm(model, act_scales, qcfg)
+        logger.info("Finished applying smooth scale")
     logger.info("==" * 20)
     if qcfg["qmodel_calibration_new"] > 0:
         logger.info("Starting to calibrate activation clip_val")
@@ -249,7 +250,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         test_dataset = load_from_disk(data_args.test_data_path)
         test_dataset = test_dataset.with_format("torch")
     elif len(pt_files) > 0:
-        test_dataset = torch.load(pt_files[0])
+        test_dataset = torch.load(pt_files[0], weights_only=False)
 
     logger.info(f"Model for evaluation: {model}")
     if qcfg["large_model"]:
@@ -258,7 +259,8 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         model.to(torch.device("cuda:0"))
     n_samples = int(test_dataset.input_ids.shape[1] / block_size)
     evaluator = Evaluator(test_dataset, "cuda", n_samples=n_samples)
-    ppl = evaluator.evaluate(model, block_size=block_size)
+    with patch_torch_bmm(qcfg):
+        ppl = evaluator.evaluate(model, block_size=block_size)
     logger.info(f"Model perplexity: {ppl}")
     logger.info("-" * 50)
     logger.info("Finished evaluation")

fms_mo/modules/linear.py

Lines changed: 30 additions & 16 deletions

@@ -29,6 +29,7 @@
 # Local
 from fms_mo.custom_ext_kernels.utils import pack_vectorized
 from fms_mo.quant.quantizers import (
+    SAWB,
     HardPrune,
     Qbypass,
     Qdynamic,
@@ -751,7 +752,7 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
             fms_mo_qlinear.in_features,
             fms_mo_qlinear.out_features,
             bias=fms_mo_qlinear.bias is not None,
-            device=tar_dev,
+            device="meta",  # init on tar_dev is unnecessary
         )
         # Make sure to register an Op for integer matmul, could be real INT matmul or emulation
         qcfg = getattr(fms_mo_qlinear, "qcfg", {})
@@ -777,39 +778,34 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
         with torch.no_grad():
             Qa = fms_mo_qlinear.quantize_feature
             Qw = fms_mo_qlinear.quantize_weight
-            w_cv = Qw.clip_val.item()
+            w_cv = Qw.clip_val
             if qlin_int.useDynMaxQfunc in [-1, -2]:  # [-1, -2] indicates reduce_dim
                 # dynamic Qmax has no clipvals, reg fake ones, won't be used in real calc
                 Qa.register_buffer("clip_val", torch.tensor(8.0, device=tar_dev))
                 Qa.register_buffer("clip_valn", torch.tensor(-8.0, device=tar_dev))
-            a_cv, a_cvn = Qa.clip_val.item(), Qa.clip_valn.item()
-            # Store original cv_a and cv_w (in python floats, not tensors), and sq scales
-            # for later use (probably not necessary)
-            qlin_int.cvs = [a_cv, a_cvn, w_cv]
-            # NOTE: Keep w transposed to prevent confusion
-            Qw.dequantize = False
-            # trigger Qw.clipval re-calc for SAWB (if needed)
-            w_int8 = Qw(fms_mo_qlinear.weight.float())
-            qlin_int.weight = nn.Parameter(
-                w_int8.to(torch.int8), requires_grad=False
-            )  # NOTE: may need INT W stored as FP in some cases
+            a_cv = Qa.clip_val
+            a_cvn = Qa.clip_valn
+            # Storing original cv_a and cv_w as python floats (instead of tensors) would be more
+            # accurate, but is not compatible with per-ch and per-token.
+            qlin_int.cvs = [a_cv, a_cvn, w_cv]  # TODO: remove the need for this.
 
+            # may need to trigger Qw.clip_val re-calc for SAWB here (if needed?)
             if qlin_int.useDynMaxQfunc in [-1, -2]:
                 input_scale = torch.tensor(1.0, device=tar_dev)
                 input_zero_point = torch.tensor(128, dtype=torch.int, device=tar_dev)
-                w_scale = torch.tensor([w_cv * 2 / w_levels], device=tar_dev)
+                w_scale = w_cv * 2 / w_levels
             elif qlin_int.usePTnativeQfunc:
                 input_scale = torch.tensor([(a_cv - a_cvn) / a_levels], device=tar_dev)
                 input_zero_point = torch.round(-a_cvn / input_scale).to(torch.int)
-                w_scale = torch.tensor([w_cv * 2 / w_levels], device=tar_dev)
+                w_scale = w_cv * 2 / w_levels
             else:
                 # fms_mo formula is a bit different from conventional PT formula
                 quant_scale = a_levels / torch.tensor([a_cv - a_cvn], device=tar_dev)
                 quant_stepsize = 1.0 / quant_scale
                 quant_zero_point = torch.round(a_cvn * quant_scale)
                 input_scale = quant_stepsize
                 input_zero_point = -quant_zero_point
-                quant_w_scale = w_levels / torch.tensor([w_cv * 2], device=tar_dev)
+                quant_w_scale = w_levels / (w_cv * 2)
                 w_scale = 1.0 / quant_w_scale
                 qlin_int.register_buffer("quant_scale", quant_scale)
                 qlin_int.register_buffer("quant_stepsize", quant_stepsize)
@@ -821,6 +817,21 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
             qlin_int.register_buffer("w_scale", w_scale)
             qlin_int.register_buffer("w_zp", w_zp)
 
+            # NOTE:
+            # 1. Keep W transposed to prevent confusion, hence (W.t() / scale).t()
+            # 2. only a few quantizers have .dequantize working correctly
+            if isinstance(Qw, SAWB):
+                Qw.dequantize = False
+                w_int8 = Qw(fms_mo_qlinear.weight.float())
+            else:
+                w_int8 = (
+                    torch.round(fms_mo_qlinear.weight.t() / w_scale)
+                    .clamp(-w_levels / 2, w_levels / 2)
+                    .t()
+                )
+
+            qlin_int.weight = nn.Parameter(w_int8.to(torch.int8), requires_grad=False)
+
             corr_term = (
                 (input_zero_point - 128 + qlin_int.useSymAct)
                 * (w_int8.sum(dim=1))
@@ -836,8 +847,11 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
                     (fms_mo_qlinear.bias - corr_term).to(fms_mo_w_dtype),
                     requires_grad=False,
                 )
+
                 qlin_int.org_model_has_bias = True
             else:
+                delattr(qlin_int, "bias")
+                # even if bias is None, reg_buffer() is still unhappy about it
                 qlin_int.register_buffer("bias", -corr_term.to(fms_mo_w_dtype))
                 qlin_int.org_model_has_bias = False
 
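For intuition, here is a small worked sketch of the scale/zero-point arithmetic in the usePTnativeQfunc branch above. The clip values and the level counts (a_levels = 255, w_levels = 254) are illustrative assumptions, not values read from the repo.

import torch

a_cv, a_cvn = 6.0, -6.0        # example activation clip values
w_cv = 0.5                     # example per-tensor weight clip value
a_levels, w_levels = 255, 254  # assumed 8-bit level counts (not taken from the repo)

input_scale = torch.tensor([(a_cv - a_cvn) / a_levels])             # ~0.047
input_zero_point = torch.round(-a_cvn / input_scale).to(torch.int)  # ~128 for a symmetric range
w_scale = torch.tensor(w_cv * 2 / w_levels)                         # ~0.0039

# Quantize a toy weight row the same way the non-SAWB path does, then reconstruct it
# to see the round-trip error that INT8 storage introduces.
w = torch.randn(8) * 0.1
w_int8 = torch.round(w / w_scale).clamp(-w_levels / 2, w_levels / 2).to(torch.int8)
w_hat = w_int8.float() * w_scale
print(input_scale.item(), input_zero_point.item(), (w - w_hat).abs().max().item())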

fms_mo/quant/quantizers.py

Lines changed: 4 additions & 3 deletions

@@ -3183,9 +3183,7 @@ class QmaxPerChSTE(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(
-        ctx, input_tensor, num_bits, _dequantize, inplace, cv, _cvn, align_zero
-    ):
+    def forward(ctx, input_tensor, num_bits, dequantize, inplace, cv, _cvn, align_zero):
         if inplace:
             ctx.mark_dirty(input_tensor)
         scale = (2**num_bits - 2) if align_zero else (2**num_bits - 1)
@@ -3206,6 +3204,9 @@ def forward(
             quant_min=int_l,
             quant_max=int_u,
         ).to(input_tensor.dtype)
+
+        if not dequantize:
+            return (output.t() / scale).t()
         return output
 
     @staticmethod
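For context, the new early return hands the caller the integer representation instead of the dequantized ("fake-quantized") tensor; scale at that point is presumed to be the per-output-channel step size used by fake_quantize_per_channel_affine earlier in forward (that part lies outside this hunk). A self-contained sketch of the same idea, under those assumptions:

import torch

# Per-channel max quantization of a [out_ch, in_ch] weight, mimicking a Qmax-style flow:
# fake-quantize (round + clamp + rescale), then divide by the per-channel scale to recover
# the integer levels, which is what the dequantize=False path returns.
w = torch.randn(3, 5)
num_bits, align_zero = 8, True
levels = (2**num_bits - 2) if align_zero else (2**num_bits - 1)

cv = w.abs().amax(dim=1)                 # per-channel clip value (max magnitude)
scale = cv * 2 / levels                  # assumed per-channel step size
zp = torch.zeros(w.shape[0], dtype=torch.int)

# args: input, scale, zero_point, axis, quant_min, quant_max
deq = torch.fake_quantize_per_channel_affine(w, scale, zp, 0, -(levels // 2), levels // 2)
int_repr = (deq.t() / scale).t()         # the dequantize=False return value
print(torch.allclose(int_repr, int_repr.round(), atol=1e-4))  # True: integer-valued levels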

fms_mo/training_args.py

Lines changed: 3 additions & 0 deletions

@@ -173,6 +173,9 @@ class FMSMOArguments(TypeChecker):
         default=2048, metadata={"help": "input sequence length after tokenization"}
     )
     eval_ppl: bool = field(default=False)
+    aiu_sim_triton: bool = field(
+        default=False, metadata={"help": ("AIU simulation with triton kernel")}
+    )
 
 
 @dataclass
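Since FMSMOArguments is a plain dataclass, the new flag can be toggled programmatically as well as through the usual argument parsing; a minimal sketch, assuming the remaining fields all have defaults:

from fms_mo.training_args import FMSMOArguments

# Request AIU simulation via the triton kernel; other fields fall back to their defaults.
fms_mo_args = FMSMOArguments(aiu_sim_triton=True)
print(fms_mo_args.aiu_sim_triton)  # True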
