
Commit a60a4b8

add a new hook for checking post-quant in/out
Signed-off-by: cliu-us <[email protected]>
1 parent 952b6d4 commit a60a4b8

File tree: 1 file changed (+112, -1 lines)


fms_mo/quant/ptq.py

Lines changed: 112 additions & 1 deletion
@@ -42,6 +42,7 @@
 # Local
 from fms_mo.modules import QBmm, QLinear
 from fms_mo.modules.conv import QConv2dPTQv2
+from fms_mo.modules.linear import LinearFPxAcc, QLinearINT8Deploy
 from fms_mo.quant.quantizers import (
     AdaRoundQuantizer,
     Qdynamic,
@@ -481,8 +482,118 @@ def __call__(self, mod, inputs, *args, **_kwargs):
         assert not self.stop_after_rec
 
 
-# this hook is meant for ptq_loss_func == 'fisher_diag' and to temp hold the "Q_out" of the module
+class HookRecPostQuantInOut(torch.nn.Module):
+    """Another simplified hook to check post-quantized input/output, e.g. within +-127 for INT8."""
+
+    def __init__(self, cache_dict={}, mod_name=None):
+        super().__init__()
+        self.cache_dict = cache_dict
+        self.mod_name = mod_name
+        name_split = mod_name.split(".")
+        self.lay_idx = int(name_split[3])
+        self.lay_key = name_split[6]
+
+        self.cache_dev = "cpu"
+        # prepare empty dict for later use
+        self.cache_dict[mod_name] = {}
+        self.fwd_mapping = {
+            LinearFPxAcc: self.call_func_for_fpxacc,
+            QLinear: self.call_func_for_qlinear,
+            QLinearINT8Deploy: self.call_func_for_qlinear_int,
+            torch.nn.Linear: self.call_func_for_nnlinear,
+        }
+
+    def call_func_for_fpxacc(self, mod, inputs, outputs, **_kwargs):
+        raise NotImplementedError
+
+    def call_func_for_qlinear(self, mod, inputs, outputs, **_kwargs):
+        lay_idx = self.lay_idx
+        lay_key = self.lay_key
+        mod_name = self.mod_name
+        cache_dict = self.cache_dict
+
+        act_max = inputs[0].abs().amax(dim=[d for d in range(len(inputs[0].shape) - 1)])
+        # mod.smoothq_act_scale
+        w_max = mod.weight.abs().max(dim=0, keepdim=True)[0].clamp(min=1e-5)
+        is_smq_layer = not torch.all(act_max == 0).item()
+        # smoothQ scale = smoothq_act_scale**alpha / weight_scale**(1.0 - alpha)
+        # smoothq_scale = mod.get_smoothq_scale(inputs[0])
+        smoothq_scale = getattr(mod, "smq_scale", 1.0)
+        # "smq_scale" only available in QLin_INT8
+
+        with torch.no_grad():
+            smoothed_inp = inputs[0] / smoothq_scale
+            smoothed_w = mod.weight * smoothq_scale
+
+            # this is assuming pertokenmax quantizer, NOTE calc quant scale after smoothing
+            absmax = smoothed_inp.abs().max(dim=-1, keepdim=True)[0]
+            qa_scale = absmax.clamp(min=1e-5) / 127
+            qinput = torch.round(smoothed_inp / qa_scale).clamp(-127, 127)
+            # should clamp to -128?
+            if mod.qa_mode == "pertokenmax":
+                # doesnt implement dequant=False yet, do it manually
+                cva = mod.quantize_feature.clip_val
+                qa_scale = cva.clamp(min=1e-5).div(127)
+                qinput = smoothed_inp.div(qa_scale).round()
+            else:
+                mod.quantize_feature.dequantize = False
+                qinput = mod.quantize_feature(smoothed_inp)
+                mod.quantize_feature.dequantize = True
+
+            # also record quantized, smoothed W in INT8, calc both maxperCh and SAWBperCh
+            cvw = mod.quantize_weight.clip_val
+            scale_w = cvw / 127
+            mod.quantize_weight.dequantize = False
+            qw = mod.quantize_weight(smoothed_w)
+            mod.quantize_weight.dequantize = True
+
+        # inputs is a tuple, QLinear only has 1 valid input
+        cache_dict[mod_name]["input"] = inputs[0].to(self.cache_dev)
+        cache_dict[mod_name]["cva"] = cva.to(self.cache_dev)
+        cache_dict[mod_name]["cvw"] = cvw.to(self.cache_dev)
+        cache_dict[mod_name]["smoothed_input"] = smoothed_inp.to(self.cache_dev)
+        cache_dict[mod_name]["smoothed_weight"] = smoothed_w.to(self.cache_dev)
+        cache_dict[mod_name]["qinput"] = qinput.to(self.cache_dev)
+        # NOTE in INT8, *scales if need dQ
+        cache_dict[mod_name]["qweight"] = qw.to(self.cache_dev)
+        # torch.round(smoothed_w.T/scale_w).clamp(-127, 127).to(self.cache_dev)
+        # cache_dict[mod_name]["qoutput"] = outputs.to(self.cache_dev)
+
+    def call_func_for_qlinear_int(self, mod, inputs, outputs, **_kwargs):
+        smoothq_scale = getattr(mod, "smq_scale", 1.0)
+        mod_name = self.mod_name
+        cache_dict = self.cache_dict
+        with torch.no_grad():
+            if mod.useDynMaxQfunc in [-1, -2]:
+                qinput = mod.qa_dynamic_max_qfunc(inputs[0])
+            elif mod.use_fake_zero_shift:
+                qinput = mod.qa_dyn_max_fake_zero_shift(inputs[0])
+            elif mod.usePTnativeQfunc:
+                qinput = mod.qa_raw_qfunc(inputs[0])
+            else:
+                qinput = mod.qa_fmo_mo_qfunc(inputs[0])
+
+        # inputs is a tuple, QLinear only has 1 valid input
+        cache_dict[mod_name]["input"] = inputs[0].to(self.cache_dev)
+        cache_dict[mod_name]["cva"] = mod.cvs[0].to(self.cache_dev)
+        cache_dict[mod_name]["cvw"] = mod.cvs[2].to(self.cache_dev)
+        cache_dict[mod_name]["qinput"] = qinput.to(self.cache_dev)
+        cache_dict[mod_name]["qweight"] = mod.weight.to(self.cache_dev)
+
+    def call_func_for_nnlinear(self, mod, inputs, outputs, **_kwargs):
+        mod_name = self.mod_name
+        cache_dict = self.cache_dict
+        cache_dict[mod_name]["input"] = inputs[0].to(self.cache_dev)
+        cache_dict[mod_name]["weight"] = mod.weight.to(self.cache_dev)
+
+    def __call__(self, mod, inputs, outputs, **_kwargs):
+        self.fwd_mapping[type(mod)](mod, inputs, outputs, **_kwargs)
+
+
 class PTQHookRecQOut(nn.Module):
+    """This hook is for ptq_loss_func == 'fisher_diag' and will temporarily hold the "Q_out" of the
+    module"""
+
     def __init__(self, qcfg):
         super().__init__()
         self.qcfg = qcfg
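For context, a minimal usage sketch of the new hook (not part of the commit): it assumes `HookRecPostQuantInOut` is importable from `fms_mo.quant.ptq`, that a quantized `model` and a `calib_batch` already exist (both are placeholders here), and that module names follow the dotted pattern the hook's `__init__` parses (`name_split[3]` as layer index, `name_split[6]` as layer key). The hook is attached with `register_forward_hook`, one forward pass fills `cache_dict`, and the cached `qinput` tensors can then be checked against the INT8 range.

# Hypothetical usage sketch (not part of the commit): register the new hook on the
# quantized Linear modules of an already-prepared model, run one calibration batch,
# then inspect the cached tensors. `model` and `calib_batch` are placeholders, and
# module names are assumed to match the "...layers.<idx>...<key>" pattern that
# HookRecPostQuantInOut.__init__ parses via mod_name.split(".").
import torch

from fms_mo.modules import QLinear
from fms_mo.modules.linear import QLinearINT8Deploy
from fms_mo.quant.ptq import HookRecPostQuantInOut

cache = {}
handles = []
for name, mod in model.named_modules():
    if isinstance(mod, (QLinear, QLinearINT8Deploy)):
        hook = HookRecPostQuantInOut(cache_dict=cache, mod_name=name)
        handles.append(mod.register_forward_hook(hook))

with torch.no_grad():
    model(**calib_batch)  # one forward pass is enough to populate the cache

for h in handles:
    h.remove()

# e.g. verify the quantized activations stay within the signed INT8 range (+-127)
for name, rec in cache.items():
    print(name, "max |qinput| =", rec["qinput"].abs().max().item())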
