 # Local
 from fms_mo.modules import QBmm, QLinear
 from fms_mo.modules.conv import QConv2dPTQv2
+from fms_mo.modules.linear import LinearFPxAcc, QLinearINT8Deploy
 from fms_mo.quant.quantizers import (
     AdaRoundQuantizer,
     Qdynamic,
@@ -481,8 +482,118 @@ def __call__(self, mod, inputs, *args, **_kwargs):
         assert not self.stop_after_rec


-# this hook is meant for ptq_loss_func == 'fisher_diag' and to temp hold the "Q_out" of the module
+class HookRecPostQuantInOut(torch.nn.Module):
+    """Another simplified hook to check a module's post-quantized input/output, e.g. that INT8
+    values stay within [-127, 127].
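+
+    A minimal usage sketch (illustrative only, not part of this module; `model` and `calib_batch`
+    are placeholders, and the dotted module names are assumed to follow the pattern parsed in
+    __init__, i.e. an integer layer index at split position 3 and a layer key at position 6):
+
+        cache = {}
+        handles = [
+            m.register_forward_hook(HookRecPostQuantInOut(cache, mod_name=n))
+            for n, m in model.named_modules()
+            if type(m) in (QLinear, QLinearINT8Deploy, torch.nn.Linear)
+        ]
+        model(**calib_batch)  # one forward pass fills "cache" with per-module tensors
+        for h in handles:
+            h.remove()
+    """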
+
+    def __init__(self, cache_dict=None, mod_name=None):
+        super().__init__()
+        # a shared dict may be passed in; otherwise create a fresh one (avoid a mutable default arg)
+        self.cache_dict = cache_dict if cache_dict is not None else {}
+        self.mod_name = mod_name
+        name_split = mod_name.split(".")
+        # NOTE: assumes a fixed dotted-name pattern; these split indices are model specific
+        self.lay_idx = int(name_split[3])
+        self.lay_key = name_split[6]
+
+        self.cache_dev = "cpu"
+        # prepare an empty per-module dict for later use
+        self.cache_dict[mod_name] = {}
+        self.fwd_mapping = {
+            LinearFPxAcc: self.call_func_for_fpxacc,
+            QLinear: self.call_func_for_qlinear,
+            QLinearINT8Deploy: self.call_func_for_qlinear_int,
+            torch.nn.Linear: self.call_func_for_nnlinear,
+        }
+
+    def call_func_for_fpxacc(self, mod, inputs, outputs, **_kwargs):
+        raise NotImplementedError
+
+    def call_func_for_qlinear(self, mod, inputs, outputs, **_kwargs):
+        lay_idx = self.lay_idx
+        lay_key = self.lay_key
+        mod_name = self.mod_name
+        cache_dict = self.cache_dict
+
+        act_max = inputs[0].abs().amax(dim=[d for d in range(len(inputs[0].shape) - 1)])
+        # cf. mod.smoothq_act_scale
+        w_max = mod.weight.abs().max(dim=0, keepdim=True)[0].clamp(min=1e-5)
+        is_smq_layer = not torch.all(act_max == 0).item()  # diagnostic; not used below
+        # smoothQ scale = smoothq_act_scale**alpha / weight_scale**(1.0 - alpha)
+        # smoothq_scale = mod.get_smoothq_scale(inputs[0])
+        # "smq_scale" is only available on QLinearINT8Deploy; default to 1.0 otherwise
+        smoothq_scale = getattr(mod, "smq_scale", 1.0)
+
+        with torch.no_grad():
+            smoothed_inp = inputs[0] / smoothq_scale
+            smoothed_w = mod.weight * smoothq_scale
+
+            # reference computation assuming a per-token max quantizer (qinput is recomputed per
+            # qa_mode below); NOTE: the quant scale is calculated after smoothing
+            absmax = smoothed_inp.abs().max(dim=-1, keepdim=True)[0]
+            qa_scale = absmax.clamp(min=1e-5) / 127
+            qinput = torch.round(smoothed_inp / qa_scale).clamp(-127, 127)
+            # should clamp to -128?
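+            # (illustrative numbers, not from the source: if a token's absmax were 2.54, then
+            #  qa_scale = 2.54 / 127 = 0.02 and an activation of 1.00 maps to round(1.00 / 0.02) = 50)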
+            # hoist cva so it is defined for the cache below regardless of qa_mode
+            cva = mod.quantize_feature.clip_val
+            if mod.qa_mode == "pertokenmax":
+                # dequantize=False is not implemented for this mode yet, so do it manually
+                qa_scale = cva.clamp(min=1e-5).div(127)
+                qinput = smoothed_inp.div(qa_scale).round()
+            else:
+                mod.quantize_feature.dequantize = False
+                qinput = mod.quantize_feature(smoothed_inp)
+                mod.quantize_feature.dequantize = True
+
+            # also record quantized, smoothed W in INT8, calc both maxperCh and SAWBperCh
+            cvw = mod.quantize_weight.clip_val
+            scale_w = cvw / 127
+            mod.quantize_weight.dequantize = False
+            qw = mod.quantize_weight(smoothed_w)
+            mod.quantize_weight.dequantize = True
+
+            # inputs is a tuple; QLinear only has one valid input
+            cache_dict[mod_name]["input"] = inputs[0].to(self.cache_dev)
+            cache_dict[mod_name]["cva"] = cva.to(self.cache_dev)
+            cache_dict[mod_name]["cvw"] = cvw.to(self.cache_dev)
+            cache_dict[mod_name]["smoothed_input"] = smoothed_inp.to(self.cache_dev)
+            cache_dict[mod_name]["smoothed_weight"] = smoothed_w.to(self.cache_dev)
+            cache_dict[mod_name]["qinput"] = qinput.to(self.cache_dev)
+            # NOTE: stored as INT8 values; multiply by the scales if dequantization is needed
+            cache_dict[mod_name]["qweight"] = qw.to(self.cache_dev)
+            # torch.round(smoothed_w.T / scale_w).clamp(-127, 127).to(self.cache_dev)
+            # cache_dict[mod_name]["qoutput"] = outputs.to(self.cache_dev)
+
+    def call_func_for_qlinear_int(self, mod, inputs, outputs, **_kwargs):
+        smoothq_scale = getattr(mod, "smq_scale", 1.0)
+        mod_name = self.mod_name
+        cache_dict = self.cache_dict
+        with torch.no_grad():
+            if mod.useDynMaxQfunc in [-1, -2]:
+                qinput = mod.qa_dynamic_max_qfunc(inputs[0])
+            elif mod.use_fake_zero_shift:
+                qinput = mod.qa_dyn_max_fake_zero_shift(inputs[0])
+            elif mod.usePTnativeQfunc:
+                qinput = mod.qa_raw_qfunc(inputs[0])
+            else:
+                qinput = mod.qa_fmo_mo_qfunc(inputs[0])
+
+            # inputs is a tuple; QLinear only has one valid input
+            cache_dict[mod_name]["input"] = inputs[0].to(self.cache_dev)
+            cache_dict[mod_name]["cva"] = mod.cvs[0].to(self.cache_dev)
+            cache_dict[mod_name]["cvw"] = mod.cvs[2].to(self.cache_dev)
+            cache_dict[mod_name]["qinput"] = qinput.to(self.cache_dev)
+            cache_dict[mod_name]["qweight"] = mod.weight.to(self.cache_dev)
+
+    def call_func_for_nnlinear(self, mod, inputs, outputs, **_kwargs):
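+        # plain nn.Linear has no quantizers attached; only the floating-point input and weight are recorded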
+        mod_name = self.mod_name
+        cache_dict = self.cache_dict
+        cache_dict[mod_name]["input"] = inputs[0].to(self.cache_dev)
+        cache_dict[mod_name]["weight"] = mod.weight.to(self.cache_dev)
+
+    def __call__(self, mod, inputs, outputs, **_kwargs):
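+        # dispatch on the exact module type; module classes not in fwd_mapping will raise KeyError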
+        self.fwd_mapping[type(mod)](mod, inputs, outputs, **_kwargs)
+
+
 class PTQHookRecQOut(nn.Module):
+    """This hook is for ptq_loss_func == 'fisher_diag' and will temporarily hold the "Q_out" of
+    the module."""
+
     def __init__(self, qcfg):
         super().__init__()
         self.qcfg = qcfg