
Commit 8e3f16a

add dynamic symmetric activation option to QLinearINT8Deploy
Signed-off-by: cliu-us <[email protected]>
1 parent 8c7a4e8 commit 8e3f16a

2 files changed: +36 −31 lines


fms_mo/custom_ext_kernels/utils.py

Lines changed: 1 addition & 1 deletion
@@ -700,7 +700,7 @@ def imatmul(m1, m2):
     tar_shape = tuple(m1.shape[:-1]) + (m2.shape[1],)
     m1 = m1.view(re_shape)
 
-    if useINTkernel:
+    if useINTkernel in ["triton", "cutlass"]:
         assert (
             m1.dtype == torch.int8 and m2.dtype == torch.int8
         ), "When using int matmul, inputs must be 2D and INT8."

fms_mo/modules/linear.py

Lines changed: 35 additions & 30 deletions
@@ -760,6 +760,7 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
         )
         qlin_int.usePTnativeQfunc = kwargs.get("use_PT_native_Qfunc", False)
         qlin_int.useDynMaxQfunc = kwargs.get("use_dynamic_max_act_Qfunc", False)
+        qlin_int.useSymAct = "sym" in fms_mo_qlinear.qa_mode
         qlin_int.max_acc_bits = kwargs.get("max_acc_bits", 32)
         qlin_int.accminmax = (
             -(1 << (qlin_int.max_acc_bits - 1)),
@@ -770,6 +771,8 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
         qlin_int.acc_dtype = torch.float16
         qlin_int.nbits_a = fms_mo_qlinear.num_bits_feature  # only support INT8 for now
         qlin_int.nbits_w = fms_mo_qlinear.num_bits_weight
+        w_levels = 2**qlin_int.nbits_w - 2
+        a_levels = 2**qlin_int.nbits_a - 1 - qlin_int.useSymAct
 
         with torch.no_grad():
             Qa = fms_mo_qlinear.quantize_feature
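Note: useSymAct trims one quantization level from the activations. Asymmetric INT8 spreads [a_cvn, a_cv] over 2**nbits_a - 1 = 255 steps, while the symmetric option keeps the grid centered on zero with 254 steps; weights stay symmetric with 2**nbits_w - 2 levels. A small sketch of the resulting scale/zero-point math (clip values below are made up):

    import torch

    def act_scale_and_zp(a_cv: float, a_cvn: float, nbits_a: int = 8, sym_act: bool = False):
        """Sketch of the per-tensor activation scale/zero-point derivation.

        Asymmetric INT8 keeps 2**n - 1 = 255 steps over [a_cvn, a_cv];
        the symmetric option drops one step (254) so the grid stays centered on zero.
        """
        a_levels = 2**nbits_a - 1 - int(sym_act)
        scale = torch.tensor([(a_cv - a_cvn) / a_levels])
        zero_point = torch.round(-a_cvn / scale).to(torch.int)
        return scale, zero_point

    # hypothetical clip values, e.g. from a trained quantizer
    print(act_scale_and_zp(a_cv=6.0, a_cvn=-6.0, sym_act=False))  # 255 levels, zp near 128
    print(act_scale_and_zp(a_cv=6.0, a_cvn=-6.0, sym_act=True))   # 254 levels, zp = 127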
@@ -794,29 +797,19 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
             if qlin_int.useDynMaxQfunc in [-1, -2]:
                 input_scale = torch.tensor(1.0, device=tar_dev)
                 input_zero_point = torch.tensor(128, dtype=torch.int, device=tar_dev)
-                w_scale = torch.tensor(
-                    [w_cv * 2 / (2**qlin_int.nbits_w - 2)], device=tar_dev
-                )
+                w_scale = torch.tensor([w_cv * 2 / w_levels], device=tar_dev)
             elif qlin_int.usePTnativeQfunc:
-                input_scale = torch.tensor(
-                    [(a_cv - a_cvn) / (2**qlin_int.nbits_a - 1)], device=tar_dev
-                )
+                input_scale = torch.tensor([(a_cv - a_cvn) / a_levels], device=tar_dev)
                 input_zero_point = torch.round(-a_cvn / input_scale).to(torch.int)
-                w_scale = torch.tensor(
-                    [w_cv * 2 / (2**qlin_int.nbits_w - 2)], device=tar_dev
-                )
+                w_scale = torch.tensor([w_cv * 2 / w_levels], device=tar_dev)
             else:
                 # fms_mo formula is a bit different from conventional PT formula
-                quant_scale = (2**qlin_int.nbits_a - 1) / torch.tensor(
-                    [a_cv - a_cvn], device=tar_dev
-                )
+                quant_scale = a_levels / torch.tensor([a_cv - a_cvn], device=tar_dev)
                 quant_stepsize = 1.0 / quant_scale
                 quant_zero_point = torch.round(a_cvn * quant_scale)
                 input_scale = quant_stepsize
                 input_zero_point = -quant_zero_point
-                quant_w_scale = (2**qlin_int.nbits_a - 2) / torch.tensor(
-                    [w_cv * 2], device=tar_dev
-                )
+                quant_w_scale = w_levels / torch.tensor([w_cv * 2], device=tar_dev)
                 w_scale = 1.0 / quant_w_scale
                 qlin_int.register_buffer("quant_scale", quant_scale)
                 qlin_int.register_buffer("quant_stepsize", quant_stepsize)
@@ -829,7 +822,7 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
             qlin_int.register_buffer("w_zp", w_zp)
 
             corr_term = (
-                (input_zero_point - 128)
+                (input_zero_point - 128 + qlin_int.useSymAct)
                 * (w_int8.sum(dim=1))
                 * w_scale.float()
                 * input_scale.float()
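Note: corr_term pre-computes the zero-point correction for the integer GEMM: since s_a * s_w * (X_q - zp) @ W_q^T = s_a * s_w * (X_q @ W_q^T) - s_a * s_w * zp * W_q.sum(dim=1), the activation zero point can be folded into a per-output-channel offset. The stored input_zero_point follows the uint8 convention, hence the -128 shift (-127 when activations are symmetric). A minimal numeric check of that identity (scales, shapes and zp below are made up):

    import torch

    torch.manual_seed(0)
    s_a, s_w, zp = 0.05, 0.01, 3                                  # hypothetical scales and zero point
    x_int = torch.randint(-128, 128, (2, 4), dtype=torch.int32)   # quantized activations
    w_int = torch.randint(-127, 128, (3, 4), dtype=torch.int32)   # [out_feat, in_feat], symmetric weights

    # reference: dequantize first, then matmul in float
    ref = ((x_int - zp).float() * s_a) @ (w_int.float() * s_w).t()

    # integer-first: raw GEMM, then subtract the per-output-channel correction term
    corr_term = zp * w_int.sum(dim=1).float() * s_w * s_a
    out = (x_int.float() @ w_int.float().t()) * s_a * s_w - corr_term

    print(torch.allclose(ref, out))  # True -- the correction term absorbs the zero point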
@@ -975,7 +968,7 @@ def qa_pt_qfunc_wrapped(self, x):
             Tensor: Quantized tensor with values in the range [-128, 127].
         """
         return torch.ops.fms_mo.q_per_t_sym(
-            x.float(), self.input_scale, self.input_zp - 128
+            x.float(), self.input_scale, self.input_zp - 128 + self.useSymAct
         )
 
     def qa_pt_quant_func(self, x):
@@ -990,15 +983,22 @@ def qa_pt_quant_func(self, x):
             Tensor: Quantized tensor with values in the range [-128, 127].
         """
         return torch.quantize_per_tensor(
-            x.float(), self.input_scale, self.input_zp - 128, torch.qint8
+            x.float(),
+            self.input_scale,
+            self.input_zp - 128 + self.useSymAct,
+            torch.qint8,
         ).int_repr()
 
     def qa_raw_qfunc(self, x):
         """
         Quantizes the input tensor x to 8-bit integer values using raw formula, slower if not
         torch.compiled
         """
-        x = torch.clamp((x / self.input_scale + self.input_zp - 128).round(), -128, 127)
+        x = torch.clamp(
+            (x / self.input_scale + self.input_zp - 128 + self.useSymAct).round(),
+            -128,
+            127,
+        )
         return x.to(torch.int8)
 
     def qa_fmo_mo_qfunc(self, x):
@@ -1007,13 +1007,10 @@ def qa_fmo_mo_qfunc(self, x):
         before rounds, as opposed to typical torch formula that rounds before clamps.
         (See qa_raw_qfunc() above.)
         """
-        x = (
-            torch.round(
-                x.clamp(self.cvs[1], self.cvs[0]) / self.quant_stepsize
-                - self.quant_zero_point
-            )
-            - 128
-        )
+        x = torch.round(
+            x.clamp(self.cvs[1], self.cvs[0]) / self.quant_stepsize
+            - self.quant_zero_point
+        ) - (128 - self.useSymAct)
         return x.to(torch.int8)
 
     def qa_dynamic_max_qfunc(self, x):
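Note: qa_fmo_mo_qfunc clamps to the learned clip values before rounding, while qa_raw_qfunc / the PT-native path round first and clamp to the int8 range afterwards; both now shift by 128 - useSymAct, so the lowest code becomes -127 under symmetric activations. A small sketch contrasting the two orderings (scale and clip values are made up, zero point assumed 0):

    import torch

    def q_round_then_clamp(x, scale, zp_int8=0):
        """PT-style: scale, shift, round, then clamp to the int8 range."""
        return torch.clamp((x / scale + zp_int8).round(), -128, 127).to(torch.int8)

    def q_clamp_then_round(x, scale, cv_pos, cv_neg, zp_int8=0):
        """fms_mo-style: clamp to the learned clip values first, then round and shift."""
        return (torch.round(x.clamp(cv_neg, cv_pos) / scale) + zp_int8).to(torch.int8)

    scale, cv = 6.0 / 127, 6.0   # hypothetical symmetric setup
    x = torch.tensor([-7.0, -0.02, 0.03, 5.99, 7.0])
    print(q_round_then_clamp(x, scale))                          # negative extreme lands on -128
    print(q_clamp_then_round(x, scale, cv_pos=cv, cv_neg=-cv))   # clip-first keeps it at -127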
@@ -1060,7 +1057,9 @@ def iaddmm_int(self, bias, m1, m2):
         Nchunk = len(idx)
         idx.append(m1.shape[1])
         accumulator = torch.zeros(
-            (m1.shape[0], m2.shape[1]), dtype=torch.float16, device=m1.device
+            (m1.shape[0], m2.shape[1]),
+            dtype=torch.int,
+            device=m1.device,  # cast float16 if needed
         )
         trun_scale = 1
         if self.truncate_lsb > 0:
@@ -1080,7 +1079,7 @@ def iaddmm_int(self, bias, m1, m2):
             # could cast to smaller data type to further simulate HW behavior, for example,
             # if HW truncates 8b from both sides of i32 accumulator, the remaining data can
             # be cast to i16 to be more realistic. pay attention to overflow handling
-            accumulator += imm_out.to(torch.float16)
+            accumulator += imm_out  # .to(torch.float16) if needed
 
         return (
             accumulator
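Note: the accumulator is now kept in int32 so the chunked partial sums stay exact integers until the optional LSB truncation and final rescale; the previous float16 accumulator could silently round large partial sums. A minimal sketch of chunked i8 x i8 accumulation in int32 (chunk size and shapes are made up, not the iaddmm_int implementation):

    import torch

    def chunked_int_matmul(m1: torch.Tensor, m2: torch.Tensor, chunk: int = 64):
        """Accumulate i8 x i8 partial products in an exact int32 accumulator.

        Chunking over the reduction dim mimics HW that sums a limited number of
        products before writing back; keeping the accumulator integer avoids the
        rounding a float16 accumulator would introduce.
        """
        assert m1.dtype == torch.int8 and m2.dtype == torch.int8
        acc = torch.zeros((m1.shape[0], m2.shape[1]), dtype=torch.int32, device=m1.device)
        for start in range(0, m1.shape[1], chunk):
            a = m1[:, start : start + chunk].to(torch.int32)
            b = m2[start : start + chunk, :].to(torch.int32)
            acc += a @ b  # exact integer partial sums
        return acc

    m1 = torch.randint(-128, 128, (4, 256), dtype=torch.int8)
    m2 = torch.randint(-128, 128, (256, 8), dtype=torch.int8)
    exact = m1.to(torch.int32) @ m2.to(torch.int32)
    print(torch.equal(chunked_int_matmul(m1, m2), exact))  # True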
@@ -1107,8 +1106,14 @@ def iaddmm_FP(self, bias, m1, m2):
         Returns:
             Tensor: the result of the matrix multiplication with addition of bias
         """
-        m2 = m2.to(m1.dtype)
-        return torch.addmm(bias, m1, m2)
+        if self.useDynMaxQfunc in [-1, -2]:
+            m1 = self.qa_dynamic_max_qfunc(m1)
+        elif self.usePTnativeQfunc:
+            m1 = self.qa_raw_qfunc(m1)
+        else:
+            m1 = self.qa_fmo_mo_qfunc(m1)
+
+        return torch.matmul(m1 * self.input_scale, m2 * self.w_scale) + bias
 
     def set_matmul_op(self):
         """
