
Commit 1f9a1cc

updated triton kernels and qlinear
Signed-off-by: cliu-us <[email protected]>
1 parent 011f184 commit 1f9a1cc

5 files changed: 145 additions & 49 deletions

fms_mo/dq.py

Lines changed: 12 additions & 0 deletions
@@ -36,6 +36,9 @@

 # Local
 from fms_mo import qconfig_init, qmodel_prep
+from fms_mo.custom_ext_kernels.utils import (
+    lower_qmodel_triton,  # pylint: disable=unused-import
+)
 from fms_mo.fx.utils import model_size_Wb
 from fms_mo.quant.ptq import (
     calibration_llm_1GPU_v2,
@@ -256,6 +259,15 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         model.save_pretrained(opt_args.output_dir, use_safetensors=True)
         tokenizer.save_pretrained(opt_args.output_dir)

+    if fms_mo_args.aiu_sim_triton:
+        lower_qmodel_triton(
+            model,
+            use_dyn_max_act=-1,
+            max_acc_bits=24,
+            num_lsb_to_truncate=8,
+            chunk_size=32,
+        )
+
     if fms_mo_args.eval_ppl:
         path_test = Path(data_args.test_data_path)
         arrow_files = list(path_test.glob("*.arrow"))
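Aside (not part of the commit): the arguments passed to `lower_qmodel_triton` above map onto the accumulator emulation in `QLinearINT8Deploy` further down this diff; `max_acc_bits` bounds the integer accumulator (see `accminmax` in linear.py) and `num_lsb_to_truncate` drops low-order bits that are later compensated by `trun_scale` at dequant time. A rough sketch of the arithmetic these settings imply, under those assumptions:

```python
import torch

# Illustrative only: what max_acc_bits=24 and num_lsb_to_truncate=8 translate to numerically.
max_acc_bits, num_lsb_to_truncate = 24, 8

# Same form as qlin_int.accminmax in linear.py: a signed 24-bit accumulator range.
acc_min = -(1 << (max_acc_bits - 1))          # -8_388_608
acc_max = (1 << (max_acc_bits - 1)) - 1       #  8_388_607

acc = torch.tensor([9_000_000, -123_456], dtype=torch.int32)
acc = acc.clamp(acc_min, acc_max)             # emulate 24-bit saturation
acc = acc >> num_lsb_to_truncate              # drop the 8 LSBs

# iaddmm_int() then compensates for the dropped bits in float:
#   out = accumulator * (trun_scale * input_scale * w_scale) + bias
trun_scale = float(2**num_lsb_to_truncate)    # 256.0
print(acc.tolist(), (acc.float() * trun_scale).tolist())
```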

fms_mo/modules/bmm.py

Lines changed: 1 addition & 1 deletion
@@ -192,7 +192,7 @@ def forward(self, m1, m2):
             torch.Tensor: Output tensor after quantized bmm.
         """
         # pylint: disable = access-member-before-definition
-        if self.calib_counter:
+        if self.calib_counter > 0:
             with torch.no_grad():
                 qm1 = self.quantize_calib_m1(m1)
                 qm2 = self.quantize_calib_m2(m2)
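Note (not from the repo): the same `calib_counter` change appears in conv.py and linear.py below. The explicit `> 0` comparison differs from plain truthiness once the counter can go negative (or is held in a tensor), as this small illustration shows:

```python
import torch

counter = torch.tensor(0)
print(bool(counter), bool(counter > 0))   # False False -> identical at zero

counter = -1                              # e.g. a counter decremented one step too far
print(bool(counter), bool(counter > 0))   # True False -> only `> 0` skips the calibration branch
```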

fms_mo/modules/conv.py

Lines changed: 1 addition & 1 deletion
@@ -270,7 +270,7 @@ def forward(self, x):
             torch.Tensor: Output tensor of shape (batch_size, out_channels, out_height, out_width).
         """
         # pylint: disable = access-member-before-definition
-        if self.calib_counter:
+        if self.calib_counter > 0:
             with torch.no_grad():
                 qinput = self.quantize_calib_feature(x)
                 qweight = self.quantize_calib_weight(self.weight)

fms_mo/modules/linear.py

Lines changed: 126 additions & 45 deletions
@@ -261,7 +261,7 @@ def forward(self, x):
             scale = torch.tensor([1.0]).to(x.dtype).to(x.device)

         # pylint: disable = access-member-before-definition
-        if self.calib_counter:
+        if self.calib_counter > 0:
             with torch.no_grad():
                 qinput = self.quantize_calib_feature(x / scale)
                 qweight = self.quantize_calib_weight(self.weight * scale)
@@ -733,6 +733,8 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
             chunk_size: some HW may have specific chunk size (BLOCK SIZE, especially in k-dim) for
                 the reason to avoid overflow/underflow problem. This can be simulated using
                 PyTorch (break a matmul into serial smaller matmuls, slow) or Triton kernel
+            useDynMaxQfunc: [-1, -2] indicates reduce_dim, 0 < val < 65 indicates an artificial
+                zero-shift, False -> use normal static quantization.

         Returns:
             A QLinearINT8Deploy object initialized with the weights and biases from the
@@ -761,7 +763,11 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
         )
         qlin_int.usePTnativeQfunc = kwargs.get("use_PT_native_Qfunc", False)
         qlin_int.useDynMaxQfunc = kwargs.get("use_dynamic_max_act_Qfunc", False)
-        qlin_int.useSymAct = "sym" in fms_mo_qlinear.qa_mode
+        qlin_int.useSymAct = (
+            "sym" in fms_mo_qlinear.qa_mode
+            or fms_mo_qlinear.qa_mode in ["pertokenmax", "max"]
+            # these are the symmetric quantizers with no "sym" in their names
+        )
         qlin_int.max_acc_bits = kwargs.get("max_acc_bits", 32)
         qlin_int.accminmax = (
             -(1 << (qlin_int.max_acc_bits - 1)),
@@ -778,26 +784,49 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
         with torch.no_grad():
             Qa = fms_mo_qlinear.quantize_feature
             Qw = fms_mo_qlinear.quantize_weight
+            # If no calibration has been run before swapping, the clipvals stored in Qw will be
+            # the original ones, e.g. per-tensor. To experiment with new quantizers, run at
+            # least one fwd first, which will update the clipvals.
+            Qw(fms_mo_qlinear.weight)
             w_cv = Qw.clip_val
-            if qlin_int.useDynMaxQfunc in [-1, -2]:  # [-1, -2] indicates reduce_dim
-                # dynamic Qmax has no clipvals, reg fake ones, won't be used in real calc
-                Qa.register_buffer("clip_val", torch.tensor(8.0, device=tar_dev))
-                Qa.register_buffer("clip_valn", torch.tensor(-8.0, device=tar_dev))
-            a_cv = Qa.clip_val
-            a_cvn = Qa.clip_valn
+            a_cv = getattr(Qa, "clip_val", torch.tensor(8.0, device=tar_dev))
+            a_cvn = getattr(Qa, "clip_valn", torch.tensor(-8.0, device=tar_dev))
             # Store original cv_a and cv_w in python floats (instead of tensors) will be more
             # accurate, but not compatible for per-ch and per-token.
-            qlin_int.cvs = [a_cv, a_cvn, w_cv]  # TODO remove the need of this.
+            qlin_int.cvs = [a_cv, a_cvn, w_cv]  # TODO remove the need of this?
+
+            # prepare smoothQuant scale = (smQ_a_scale ^ alpha) / (smQ_w_scale ^ (1 - alpha))
+            smq_scale = torch.tensor([1.0], device=tar_dev, dtype=fms_mo_w_dtype)
+            if getattr(fms_mo_qlinear, "smoothq", False):
+                smq_a_scale = fms_mo_qlinear.smoothq_act_scale
+                smq_w_scale = (
+                    fms_mo_qlinear.weight.abs()
+                    .max(dim=0, keepdim=True)[0]
+                    .clamp(min=1e-5)
+                )
+                smq_alpha = fms_mo_qlinear.smoothq_alpha
+                if torch.all(smq_a_scale != 0).item():
+                    smq_scale = (
+                        (smq_a_scale**smq_alpha / smq_w_scale ** (1.0 - smq_alpha))
+                        .clamp(min=1e-5)
+                        .to(smq_a_scale.dtype)
+                    )

-            # may need to trigger Qw.clipval re-calc for SAWB here, (if needed?)
+            # could trigger Qw.clipval re-calc for SAWB here, if needed
+            input_scale = torch.tensor(1.0, device=tar_dev)
+            w_scale = w_cv * 2 / w_levels
+            qlin_int.use_fake_zero_shift = False
             if qlin_int.useDynMaxQfunc in [-1, -2]:
-                input_scale = torch.tensor(1.0, device=tar_dev)
-                input_zero_point = torch.tensor(128, dtype=torch.int, device=tar_dev)
-                w_scale = w_cv * 2 / w_levels
+                input_zero_point = torch.tensor(
+                    128 - qlin_int.useSymAct, device=tar_dev
+                )
+            elif 0 < qlin_int.useDynMaxQfunc < 65:
+                # introduce a fake zero-shift; input_scale will be calc'ed dynamically
+                qlin_int.use_fake_zero_shift = True
+                input_zero_point = torch.tensor(qlin_int.useDynMaxQfunc, device=tar_dev)
             elif qlin_int.usePTnativeQfunc:
                 input_scale = torch.tensor([(a_cv - a_cvn) / a_levels], device=tar_dev)
-                input_zero_point = torch.round(-a_cvn / input_scale).to(torch.int)
-                w_scale = w_cv * 2 / w_levels
+                input_zero_point = torch.round(-a_cvn / input_scale)
             else:
                 # fms_mo formula is a bit different from conventional PT formula
                 quant_scale = a_levels / torch.tensor([a_cv - a_cvn], device=tar_dev)
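For context (illustrative, not part of this commit): the SmoothQuant scale prepared above, `s = smq_a_scale**alpha / smq_w_scale**(1 - alpha)`, is multiplied into the weights at conversion time and divided out of the activations in `forward()` (last hunk of this file), so the floating-point product is unchanged. A minimal numeric sketch of that invariance; the per-channel `a_scale` below is computed directly from `x` purely for illustration, whereas the module uses the calibrated `smoothq_act_scale`:

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 8)            # activations [batch, in_features]
w = torch.randn(16, 8)           # weights     [out_features, in_features]

a_scale = x.abs().max(dim=0, keepdim=True)[0].clamp(min=1e-5)    # stand-in for smoothq_act_scale
w_scale = w.abs().max(dim=0, keepdim=True)[0].clamp(min=1e-5)    # per-in-channel weight max
alpha = 0.5
s = (a_scale**alpha / w_scale ** (1.0 - alpha)).clamp(min=1e-5)  # shape [1, in_features]

ref = x @ w.t()
smoothed = (x / s) @ (w * s).t()   # scale migrated from activations onto weights
print(torch.allclose(ref, smoothed, atol=1e-5))   # True up to float rounding
```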
@@ -812,48 +841,70 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
             qlin_int.register_buffer("quant_zero_point", quant_zero_point)
             w_zp = torch.zeros_like(w_scale, dtype=torch.int)

+            input_zero_point = input_zero_point.to(torch.int)  # note 2 in pre-compute
             qlin_int.register_buffer("input_scale", input_scale)
             qlin_int.register_buffer("input_zp", input_zero_point)
             qlin_int.register_buffer("w_scale", w_scale)
             qlin_int.register_buffer("w_zp", w_zp)
+            qlin_int.register_buffer("smq_scale", smq_scale)

             # NOTE:
             # 1. Keep W transposed to prevent confusion, hence (W.t()/scale).t()
-            # 2. only a few quantizer have .dequantize working correctly
+            # 2. only a few quantizers have .dequantize working correctly, e.g. SAWB
+            # 3. the smooth_quant factor is included in W here; it is also applied in the forward
             if isinstance(Qw, SAWB):
                 Qw.dequantize = False
-                w_int8 = Qw(fms_mo_qlinear.weight.float())
+                w_int8 = Qw(fms_mo_qlinear.weight.float() * smq_scale)
             else:
                 w_int8 = (
-                    torch.round(fms_mo_qlinear.weight.t() / w_scale)
+                    torch.round((fms_mo_qlinear.weight * smq_scale).t() / w_scale)
                     .clamp(-w_levels / 2, w_levels / 2)
                     .t()
                 )
-
+            w_int8 = w_int8.to(
+                torch.int
+            )  # stored as int32 since the correction term needs sum()
             qlin_int.weight = nn.Parameter(w_int8.to(torch.int8), requires_grad=False)

-            corr_term = (
-                (input_zero_point - 128 + qlin_int.useSymAct)
-                * (w_int8.sum(dim=1))
-                * w_scale.float()
-                * input_scale.float()
-            )
-            # dim=1 because w_int is in [out,in], after sum shape=[out,], same as w_scale and bias.
-            # (zp-128)*w_int8.sum(dim=1) can be >> fp16.max, use fp32 scales
-            # to make sure dtype is large enough
-            qlin_int.register_buffer("corr_term", corr_term.half())  # [DEBUG only]
-            if fms_mo_qlinear.bias is not None:
-                qlin_int.bias = nn.Parameter(
-                    (fms_mo_qlinear.bias - corr_term).to(fms_mo_w_dtype),
-                    requires_grad=False,
-                )
+            # Pre-compute the "correction term" for zero-shift for asym activation quantizers
+            # NOTE:
+            # 1. sym act should have corr_term=0, unless we want to introduce a fake zero-shift
+            # 2. sum to reduce dim=1 because w_int is in [out,in], after sum shape=[out,], same as
+            #    w_scale (per-Ch) and bias.
+            # 3. calc the INT part, i.e. (zp-128)*w_int8.sum(dim=1), first in INT32, because it can
+            #    easily be >> fp16.max (~65535); make sure not to cast INT32 to FP16 during the
+            #    calc, simply cast the scales to FP32
+            # 4. for the "fake zero-shift" case, input_scale will be max/(127-fake_zero_shift)
+            #    instead of max/127, see qa_dyn_max_fake_zero_shift()
+            # 5. Combine the correction term into linear.bias for non-dynamic cases. For dyn quant,
+            #    input_scale is a placeholder for now and will be calc'ed on the fly later.
+            if qlin_int.useSymAct:
+                corr_term_int = 0
+                if qlin_int.use_fake_zero_shift:
+                    # one exception, fake zero-shift
+                    corr_term_int = input_zero_point * (w_int8.sum(dim=1))
+            else:
+                corr_term_int = (input_zero_point - 128) * (w_int8.sum(dim=1))

-                qlin_int.org_model_has_bias = True
+            qlin_int.register_buffer(
+                "corr_term", corr_term_int * w_scale.float() * input_scale.float()
+            )  # keep in FP32, cast at the end
+
+            qlin_int.org_model_has_bias = fms_mo_qlinear.bias is not None
+            # Combine the correction term into linear.bias when possible. NOTE the magnitude of
+            # these 2 terms could vary a lot; use fp32 to avoid underflow and losing accuracy.
+            if qlin_int.org_model_has_bias:
+                new_bias = fms_mo_qlinear.bias.float() - qlin_int.corr_term
             else:
-                delattr(qlin_int, "bias")
-                # even if bias is None, reg_buffer() is still unhappy about it
-                qlin_int.register_buffer("bias", -corr_term.to(fms_mo_w_dtype))
-                qlin_int.org_model_has_bias = False
+                new_bias = -qlin_int.corr_term
+
+            if qlin_int.use_fake_zero_shift:
+                # dyn sym act but with fake zp, remove corr_term from bias
+                new_bias += qlin_int.corr_term
+
+            delattr(qlin_int, "bias")
+            # sometimes reg_buffer() is unhappy about an existing bias
+            qlin_int.register_buffer("bias", new_bias.to(fms_mo_w_dtype))

             # redundant variables to be cleaned up
             # qlin_int.register_buffer("Qa_clip_val", Qa.clip_val.detach())
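Aside (not from the repo): the rewritten block above follows from the asymmetric dequantization identity. With x ≈ s_a·(q_x − zp) and W ≈ s_w·q_w, we get x·Wᵀ ≈ s_a·s_w·(q_x·q_wᵀ) − s_a·s_w·zp·Σ_k q_w[:, k]; the second term depends only on the frozen weights and the zero point, which is why it can be pre-computed and folded into the bias. A toy per-tensor check of that identity, with conventions simplified relative to the code (the int8 re-centering by 128 is only mimicked):

```python
import torch

torch.manual_seed(0)
x = torch.rand(4, 8) * 2.0          # skewed (non-negative) activations -> asymmetric quant
w = torch.randn(16, 8)

# Toy per-tensor activation quantization, stored as signed int8 codes (re-centered by 128).
a_cv, a_cvn, a_levels = x.max(), x.min(), 255
s_a = (a_cv - a_cvn) / a_levels
zp0 = torch.round(-a_cvn / s_a)                            # "unsigned" zero point
q_x = (torch.round(x / s_a) + zp0 - 128).clamp(-128, 127)  # signed int8-ranged codes
zp = zp0 - 128                                             # plays the role of (input_zero_point - 128)

# Symmetric per-tensor weight quantization.
s_w = w.abs().max() / 127
q_w = torch.round(w / s_w)

full = s_a * s_w * ((q_x - zp) @ q_w.t())      # dequantized matmul, done directly
corr = s_a * s_w * zp * q_w.sum(dim=1)         # shape [out]: the weight-only "correction term"
split = s_a * s_w * (q_x @ q_w.t()) - corr     # same result with the term folded into the bias
print(torch.allclose(full, split, atol=1e-4))  # True
```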
@@ -1039,9 +1090,25 @@ def qa_dynamic_max_qfunc(self, x):
        """
        amax = x.abs().max(dim=self.useDynMaxQfunc, keepdim=True)[0]
        levels = 2 ** (self.nbits_a - 1) - 1
+        self.cvs[0] = amax
+        self.cvs[1] = -amax
        self.input_scale = amax.clamp(min=1e-5).div(levels)
        return torch.round(x / self.input_scale).to(torch.int8)

+    def qa_dyn_max_fake_zero_shift(self, x):
+        """Dynamic max quantizer with a fake zero-shift in order to accommodate "zero-centered"
+        activations. The "partial" correction term has been pre-computed in from_fms_mo() but
+        still needs to be multiplied by input_scale. (Assuming per-tensor, can shift left or right)
+        """
+        amax = x.abs().max()
+        shift_dir = 1 if amax == x.max() else -1
+        levels = 2 ** (self.nbits_a - 1) - 1 - self.input_zp
+        self.cvs[0] = amax
+        self.cvs[1] = -amax
+        self.input_scale = amax.clamp(min=1e-5) / levels
+        xq = torch.round(x / self.input_scale) + self.input_zp
+        return xq.to(torch.int8)
+
    def iaddmm_int(self, bias, m1, m2):
        """
        Performs integer matrix multiplication with optional addition of a bias term.
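For context (illustrative, not part of this commit): `qa_dyn_max_fake_zero_shift()` shrinks the dynamic range to `127 - input_zp` levels and then adds `input_zp` to the codes, so the shifted result still fits in int8; this is the scale change described in note 4 of the correction-term comment above. A quick numeric sanity check under those assumptions, using the 8-bit path and a shift of 16:

```python
import torch

nbits_a, zp = 8, 16                        # an artificial zero-shift of 16 codes
x = torch.randn(4, 8) * 3.0

amax = x.abs().max()
levels = 2 ** (nbits_a - 1) - 1 - zp       # 127 - 16 = 111
scale = amax.clamp(min=1e-5) / levels      # amax / (127 - shift) instead of amax / 127
xq = torch.round(x / scale) + zp           # codes land in roughly [-111 + 16, 111 + 16]
print(int(xq.min()), int(xq.max()), bool(xq.abs().max() <= 127))   # always within int8 range
xq = xq.to(torch.int8)                     # safe: the shifted range still fits in int8
```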
@@ -1061,11 +1128,14 @@ def iaddmm_int(self, bias, m1, m2):

        if self.useDynMaxQfunc in [-1, -2]:
            m1 = self.qa_dynamic_max_qfunc(m1)
+        elif self.use_fake_zero_shift:
+            m1 = self.qa_dyn_max_fake_zero_shift(m1)
        elif self.usePTnativeQfunc:
            m1 = self.qa_raw_qfunc(m1)
        else:
            m1 = self.qa_fmo_mo_qfunc(m1)

+        # NOTE simulating chunk behavior in PyTorch is serial and slow; use triton when possible
        if m1.shape[1] > self.chunk_size and self.use_int_kernel != "triton":
            idx = list(range(0, m1.shape[1], self.chunk_size))
            Nchunk = len(idx)
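Aside (illustrative, not the repo's exact loop): the NOTE added above refers to the PyTorch fallback that walks the k-dim in `chunk_size` slices and accumulates them serially, so a narrow accumulator (`accminmax`) can be emulated; the Triton kernel does the equivalent in-kernel. A rough sketch of that style of emulation:

```python
import torch

def chunked_int_matmul(m1, m2, chunk_size=32, acc_bits=24):
    """Emulate a narrow accumulator: matmul the k-dim in chunks, clamping the running sum."""
    acc_min, acc_max = -(1 << (acc_bits - 1)), (1 << (acc_bits - 1)) - 1
    acc = torch.zeros(m1.shape[0], m2.shape[1], dtype=torch.int64)
    for k0 in range(0, m1.shape[1], chunk_size):
        part = m1[:, k0:k0 + chunk_size].long() @ m2[k0:k0 + chunk_size, :].long()
        acc = (acc + part).clamp(acc_min, acc_max)   # saturate instead of wrapping around
    return acc

m1 = torch.randint(-128, 128, (4, 64), dtype=torch.int8)   # already-quantized activations
m2 = torch.randint(-128, 128, (64, 16), dtype=torch.int8)  # already-quantized weights (k x n)
out = chunked_int_matmul(m1, m2)
print(out.shape, out.dtype)   # torch.Size([4, 16]) torch.int64
```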
@@ -1099,11 +1169,19 @@ def iaddmm_int(self, bias, m1, m2):
                accumulator
                * (trun_scale * self.input_scale * self.w_scale)  # .to(torch.float16)
                + bias
-            ).to(self.acc_dtype)
-            # The safest casting, i32 -> f32
+            ).to(self.acc_dtype)  # safest casting would be i32 -> f32
+
        imm_out = torch.ops.fms_mo.imatmul(m1, m2)
+
+        updated_bias = bias
+        if self.use_fake_zero_shift:
+            # Do NOT change the stored self.corr_term and self.bias
+            updated_bias = bias - self.input_scale * self.corr_term
+
+        # cast to fp16 could be modified based on real HW behavior/design
        return (
-            imm_out.float() * (self.input_scale * self.w_scale).to(torch.float16) + bias
+            imm_out.float() * (self.input_scale * self.w_scale).to(torch.float16)
+            + updated_bias
        ).to(self.acc_dtype)

    def iaddmm_FP(self, bias, m1, m2):
@@ -1247,9 +1325,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
            self.weight.shape[0],
        )  # W.shape=[out,in]

-        x = self.iaddmm(self.bias, x.view(re_shape), self.weight.t()).reshape(
-            tar_shape
-        )
+        if torch.all(self.smq_scale != 1).item():
+            x = x.view(re_shape) / self.smq_scale
+        else:
+            x = x.view(re_shape)
+
+        x = self.iaddmm(self.bias, x, self.weight.t()).reshape(tar_shape)

        return x.to(org_dtype)

fms_mo/quant/quantizers.py

Lines changed: 5 additions & 2 deletions
@@ -3476,11 +3476,14 @@ def __init__(self, num_bits):
        """
        super().__init__()
        self.num_bits = num_bits
+        self.register_buffer("clip_val", torch.Tensor([0.0]))
+        self.register_buffer("clip_valn", torch.Tensor([0.0]))

    def forward(self, input_tensor):
-        scales = input_tensor.abs().max(dim=-1, keepdim=True)[0]
+        self.clip_val = input_tensor.abs().max(dim=-1, keepdim=True)[0]
+        self.clip_valn = -self.clip_val
        levels = 2 ** (self.num_bits - 1) - 1
-        scales.clamp_(min=1e-5).div_(levels)
+        scales = self.clip_val.clamp(min=1e-5).div(levels)
        input_tensor.div_(scales).round_().mul_(scales)
        return input_tensor

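Closing aside (not part of the commit): with the new `clip_val` / `clip_valn` buffers, this dynamic per-token max quantizer now exposes the values it quantized with, which is what `QLinearINT8Deploy.from_fms_mo()` reads via `getattr(Qa, "clip_val", ...)` in linear.py above. A standalone restatement of the forward logic (8-bit case) showing what those buffers end up holding and that the in-place fake-quantization error stays within half a step:

```python
import torch

num_bits = 8
x = torch.randn(2, 4) * 5.0
x_orig = x.clone()

clip_val = x.abs().max(dim=-1, keepdim=True)[0]     # per-token (last-dim) max
clip_valn = -clip_val
levels = 2 ** (num_bits - 1) - 1
scales = clip_val.clamp(min=1e-5).div(levels)
x.div_(scales).round_().mul_(scales)                # in-place fake quantization, as in forward()

print(clip_val.squeeze().tolist())                                  # per-token clip values now exposed
print(bool((x - x_orig).abs().max() <= scales.max() / 2 + 1e-6))    # True: error <= half a step
```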