
Commit 8c7a4e8

add dynamic act quantizer option (something like pertokenmax) for QLinearINT8Deploy
Signed-off-by: cliu-us <[email protected]>
1 parent 8b570a7 commit 8c7a4e8

2 files changed: 62 additions & 43 deletions


fms_mo/custom_ext_kernels/utils.py

Lines changed: 10 additions & 7 deletions
@@ -633,14 +633,17 @@ def exv2_i4f16_fxinputs_abstract(


 def imatmul_ops_reg(
-    useCUTLASS=True, mm_func=torch.matmul, AB_dtype=torch.float, D_dtype=torch.float
+    useINTkernel="triton",
+    mm_func=torch.matmul,
+    AB_dtype=torch.float,
+    D_dtype=torch.float,
 ):
     """This function will register a dummy Q_imatmul Op for better "graph representation".
     Args:
-        useCUTLASS: bool. choose to use a) real INT matmul using cutlass kernel or b) "simulated"
-            imatmul using torch.matmul.
+        useINTkernel: str|bool. ["cutlass", "triton", False]. choose to use a) real INT matmul, e.g.
+            cutlass or triton kernel, or b) "simulated" imatmul using torch.matmul.
            For b), could use D_dtype to select fp16 or fp32 accumulation
-        mm_func: matmul func to be used when useCUTLASS is True, should be a real callable kernel
+        mm_func: matmul func to be used when useINTkernel is True, should be a real callable kernel
            from cutlass, but for debug purpose, could use torch.matmul as well.
        AB_dtype: datatype for input tensors
        D_dtype: datatype for accumulation and output tensor
@@ -697,10 +700,10 @@ def imatmul(m1, m2):
         tar_shape = tuple(m1.shape[:-1]) + (m2.shape[1],)
         m1 = m1.view(re_shape)

-        if useCUTLASS:
+        if useINTkernel:
             assert (
                 m1.dtype == torch.int8 and m2.dtype == torch.int8
-            ), "When using cutlass int matmul, inputs must be 2D INT8"
+            ), "When using int matmul, inputs must be 2D and INT8."
             return mm_func(m1, m2).reshape(tar_shape)

         outf32_or_f16 = torch.empty(
@@ -759,7 +762,7 @@ def q_iaddmm_dq(bias, m1, m2, scale_i, zp_i, scale_w):
         assert m2.dtype == torch.int8, f"weight tensor is of incorrect dtype {m2.dtype}"
         m1 = torch.clamp((m1 / scale_i + zp_i - 128).round(), -128, 127).to(torch.int8)

-        if useCUTLASS:
+        if useINTkernel:
             mm_i32 = mm_func(m1, m2)
         else:
             outf32_or_f16 = torch.empty(
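For orientation, here is a minimal self-contained sketch of the dispatch the renamed useINTkernel argument controls. This is not the registered Q_imatmul Op itself; re_shape and the defaults are assumptions based on the surrounding diff, and the real-kernel branch needs an INT8-capable mm_func (cutlass/triton wrapper) rather than the torch.matmul debug default.

# Hedged sketch of the useINTkernel dispatch, not the actual registered Op.
import torch

def imatmul_sketch(m1, m2, useINTkernel="triton", mm_func=torch.matmul, D_dtype=torch.float):
    """m1: (..., K) activations, m2: (K, N) weights. Assumes re_shape flattens leading dims."""
    re_shape = (-1, m1.shape[-1])
    tar_shape = tuple(m1.shape[:-1]) + (m2.shape[1],)
    m1 = m1.view(re_shape)
    if useINTkernel:  # "cutlass"/"triton": mm_func must be a real INT8 kernel wrapper
        assert m1.dtype == torch.int8 and m2.dtype == torch.int8
        return mm_func(m1, m2).reshape(tar_shape)
    # Simulated path: emulate the integer matmul with torch.matmul, accumulate in D_dtype
    return torch.matmul(m1.to(D_dtype), m2.to(D_dtype)).reshape(tar_shape)

# The emulation path runs anywhere; the real-kernel path requires an INT8-capable mm_func.
out = imatmul_sketch(torch.randn(2, 4, 8), torch.randn(8, 3), useINTkernel=False)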

fms_mo/modules/linear.py

Lines changed: 52 additions & 36 deletions
@@ -742,7 +742,7 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
             for a_or_w in ["num_bits_feature", "num_bits_weight"]
         ), "Please check nbits setting!"

-        target_device = kwargs.get(
+        tar_dev = kwargs.get(
             "target_device",
             kwargs.get("device", next(fms_mo_qlinear.parameters()).device),
         )
@@ -751,14 +751,15 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
             fms_mo_qlinear.in_features,
             fms_mo_qlinear.out_features,
             bias=fms_mo_qlinear.bias is not None,
-            device=target_device,
+            device=tar_dev,
         )
         # Make sure to register an Op for integer matmul, could be real INT matmul or emulation
         qcfg = getattr(fms_mo_qlinear, "qcfg", {})
         qlin_int.use_int_kernel = kwargs.get(
             "use_int_kernel", qcfg.get("use_int_kernel", "cutlass")
         )
         qlin_int.usePTnativeQfunc = kwargs.get("use_PT_native_Qfunc", False)
+        qlin_int.useDynMaxQfunc = kwargs.get("use_dynamic_max_act_Qfunc", False)
         qlin_int.max_acc_bits = kwargs.get("max_acc_bits", 32)
         qlin_int.accminmax = (
             -(1 << (qlin_int.max_acc_bits - 1)),
@@ -773,34 +774,48 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
         with torch.no_grad():
             Qa = fms_mo_qlinear.quantize_feature
             Qw = fms_mo_qlinear.quantize_weight
-            a_cv, a_cvn = Qa.clip_val.item(), Qa.clip_valn.item()
             w_cv = Qw.clip_val.item()
+            if qlin_int.useDynMaxQfunc in [-1, -2]:  # [-1, -2] indicates reduce_dim
+                # dynamic Qmax has no clipvals, reg fake ones, won't be used in real calc
+                Qa.register_buffer("clip_val", torch.tensor(8.0, device=tar_dev))
+                Qa.register_buffer("clip_valn", torch.tensor(-8.0, device=tar_dev))
+            a_cv, a_cvn = Qa.clip_val.item(), Qa.clip_valn.item()
+            # Store original cv_a and cv_w (in python floats, not tensors), and sq scales
+            # for later use (probably not necessary)
+            qlin_int.cvs = [a_cv, a_cvn, w_cv]
             # NOTE: Keep w transposed to prevent confusion
             Qw.dequantize = False
-            w_int8 = Qw(
-                fms_mo_qlinear.weight.float()
-            )  # Qw.clipval should have been updated after this
+            # trigger Qw.clipval re-calc for SAWB (if needed)
+            w_int8 = Qw(fms_mo_qlinear.weight.float())
             qlin_int.weight = nn.Parameter(
                 w_int8.to(torch.int8), requires_grad=False
             )  # NOTE: may need INT W stored as FP in some cases

-            if qlin_int.usePTnativeQfunc:
+            if qlin_int.useDynMaxQfunc in [-1, -2]:
+                input_scale = torch.tensor(1.0, device=tar_dev)
+                input_zero_point = torch.tensor(128, dtype=torch.int, device=tar_dev)
+                w_scale = torch.tensor(
+                    [w_cv * 2 / (2**qlin_int.nbits_w - 2)], device=tar_dev
+                )
+            elif qlin_int.usePTnativeQfunc:
                 input_scale = torch.tensor(
-                    [(a_cv - a_cvn) / (2**qlin_int.nbits_a - 1)], device=target_device
+                    [(a_cv - a_cvn) / (2**qlin_int.nbits_a - 1)], device=tar_dev
                 )
                 input_zero_point = torch.round(-a_cvn / input_scale).to(torch.int)
-                w_scale = torch.tensor([w_cv * 2 / (2**qlin_int.nbits_w - 2)])
+                w_scale = torch.tensor(
+                    [w_cv * 2 / (2**qlin_int.nbits_w - 2)], device=tar_dev
+                )
             else:
                 # fms_mo formula is a bit different from conventional PT formula
                 quant_scale = (2**qlin_int.nbits_a - 1) / torch.tensor(
-                    [a_cv - a_cvn], device=target_device
+                    [a_cv - a_cvn], device=tar_dev
                 )
                 quant_stepsize = 1.0 / quant_scale
                 quant_zero_point = torch.round(a_cvn * quant_scale)
                 input_scale = quant_stepsize
                 input_zero_point = -quant_zero_point
                 quant_w_scale = (2**qlin_int.nbits_a - 2) / torch.tensor(
-                    [w_cv * 2], device=target_device
+                    [w_cv * 2], device=tar_dev
                 )
                 w_scale = 1.0 / quant_w_scale
             qlin_int.register_buffer("quant_scale", quant_scale)
@@ -812,9 +827,6 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
             qlin_int.register_buffer("input_zp", input_zero_point)
             qlin_int.register_buffer("w_scale", w_scale)
             qlin_int.register_buffer("w_zp", w_zp)
-            # Store original cv_a and cv_w (in python floats, not tensors), and sq scales
-            # for later verification
-            qlin_int.cvs = [Qa.clip_val.item(), Qa.clip_valn.item(), Qw.clip_val.item()]

             corr_term = (
                 (input_zero_point - 128)
@@ -836,17 +848,14 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
                 qlin_int.register_buffer("bias", -corr_term.to(fms_mo_w_dtype))
                 qlin_int.org_model_has_bias = False

-            qlin_int.register_buffer("Qa_clip_val", Qa.clip_val.detach())
-            qlin_int.register_buffer(
-                "Qa_clip_valn", Qa.clip_valn.detach()
-            )  # TODO: case for PACT?
-            qlin_int.register_buffer(
-                "Qw_clip_val", Qw.clip_val.detach()
-            )  # asym W quantizer may have clipvaln
+            # redundant variables to be cleaned up
+            # qlin_int.register_buffer("Qa_clip_val", Qa.clip_val.detach())
+            # qlin_int.register_buffer("Qa_clip_valn", Qa.clip_valn.detach())
+            # qlin_int.register_buffer("Qw_clip_val", Qw.clip_val.detach())

             qlin_int.set_matmul_op()

-        return qlin_int.to(target_device)
+        return qlin_int.to(tar_dev)

     @classmethod
     def from_torch_iW(cls, nnlin_iW, prec, a_cv, a_cvn, w_cv, zero_shift, **kwargs):
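A hedged usage sketch of the new option: the kwarg names come straight from the diff, while the import path and the variable qlin (a trained fms-mo QLinear) are assumptions. Passing -1 requests per-token dynamic max quantization, -2 per-channel, and False keeps the static clip-value path.

# Illustrative call only; QLinearINT8Deploy and its module path are assumed from the
# commit message and the edited file (fms_mo/modules/linear.py).
from fms_mo.modules.linear import QLinearINT8Deploy

qlin_int = QLinearINT8Deploy.from_fms_mo(
    qlin,                              # a trained fms-mo QLinear (assumed to exist)
    use_int_kernel="triton",           # or "cutlass", or False to emulate with torch.matmul
    use_dynamic_max_act_Qfunc=-1,      # -1: per-token amax, -2: per-channel amax, False: static
    target_device="cuda",
)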
@@ -988,25 +997,15 @@ def qa_raw_qfunc(self, x):
         """
         Quantizes the input tensor x to 8-bit integer values using raw formula, slower if not
         torch.compiled
-
-        Args:
-            x (Tensor): Input tensor to be quantized.
-
-        Returns:
-            Tensor: Quantized tensor with values in the range [-128, 127].
         """
         x = torch.clamp((x / self.input_scale + self.input_zp - 128).round(), -128, 127)
         return x.to(torch.int8)

     def qa_fmo_mo_qfunc(self, x):
         """
-        Quantizes the input tensor x to 8-bit integer values.
-
-        Args:
-            x (Tensor): Input tensor to be quantized.
-
-        Returns:
-            Tensor: Quantized tensor with values in the range [-128, 127].
+        Quantizes the input tensor x to 8-bit integer values. Note that old fms-mo formula clamps
+        before rounds, as opposed to typical torch formula that rounds before clamps.
+        (See qa_raw_qfunc() above.)
         """
         x = (
             torch.round(
@@ -1017,6 +1016,21 @@ def qa_fmo_mo_qfunc(self, x):
         )
         return x.to(torch.int8)

+    def qa_dynamic_max_qfunc(self, x):
+        """
+        Symmetric dynamic quantizer, same as QDynMax, which allows per-token or per-channel.
+        This quantizer will not use self.input_scale but instead will update it every time.
+        NOTE
+        1. self.input_scale.shape should be (x.shape[-2], 1) if reduce_dim == -1 and
+           (1, x.shape[-1]) for reduce_dim == -2.
+        2. input_scale should be broadcast correctly together with W_scale (e.g. if per-Ch) at
+           the final output step, i.e. imm_out*(a_scale*w_scale)*...
+        """
+        amax = x.abs().max(dim=self.useDynMaxQfunc, keepdim=True)[0]
+        levels = 2 ** (self.nbits_a - 1) - 1
+        self.input_scale = amax.clamp(min=1e-5).div(levels)
+        return torch.round(x / self.input_scale).to(torch.int8)
+
     def iaddmm_int(self, bias, m1, m2):
         """
         Performs integer matrix multiplication with optional addition of a bias term.
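To make the broadcasting NOTE above concrete, here is a minimal standalone sketch (shapes and the float-emulated matmul are illustrative, not the module's actual forward): with reduce_dim = -1 the activation scale has shape (..., T, 1), so it broadcasts against the accumulated matmul output (..., T, N) and a per-channel weight scale of shape (N,) at the final rescale imm_out * (a_scale * w_scale).

# Hedged sketch of per-token dynamic max quantization and the final rescale broadcast.
import torch

def dyn_max_quant(x, nbits_a=8, reduce_dim=-1):
    """Symmetric dynamic quantizer: per-token if reduce_dim == -1, per-channel if -2."""
    levels = 2 ** (nbits_a - 1) - 1                      # 127 for 8 bits
    amax = x.abs().max(dim=reduce_dim, keepdim=True)[0]  # (..., T, 1) when reduce_dim == -1
    a_scale = amax.clamp(min=1e-5) / levels
    return torch.round(x / a_scale).to(torch.int8), a_scale

x = torch.randn(2, 5, 16)                                     # (batch, T tokens, K features)
w_int8 = torch.randint(-127, 128, (16, 8), dtype=torch.int8)  # (K, N), kept transposed
w_scale = torch.full((8,), 0.01)                              # per-channel W scale, shape (N,)

x_int8, a_scale = dyn_max_quant(x)                            # a_scale: (2, 5, 1)
imm_out = torch.matmul(x_int8.float(), w_int8.float())        # "simulated" integer matmul
y = imm_out * (a_scale * w_scale)                             # broadcasts to (2, 5, 8)
print(y.shape)                                                # torch.Size([2, 5, 8])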
@@ -1034,7 +1048,9 @@ def iaddmm_int(self, bias, m1, m2):
             The result of the integer matrix multiplication with the bias added.
         """

-        if self.usePTnativeQfunc:
+        if self.useDynMaxQfunc in [-1, -2]:
+            m1 = self.qa_dynamic_max_qfunc(m1)
+        elif self.usePTnativeQfunc:
             m1 = self.qa_raw_qfunc(m1)
         else:
             m1 = self.qa_fmo_mo_qfunc(m1)
