39 | 39 | qscheme_per_tensor = Qscheme( |
40 | 40 | unit="perT", |
41 | 41 | symmetric=False, |
42 | | - Nch=None, |
43 | | - Ngrp=None, |
44 | 42 | single_sided=False, |
45 | 43 | qlevel_lowering=False, |
46 | 44 | ) |
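For orientation, here is a minimal sketch of the asymmetric per-tensor scale/zero-point math that a scheme like `Qscheme(unit="perT", symmetric=False, ...)` conventionally implies. This is a generic illustration under textbook min/max assumptions; `pertensor_asym_qparams` is a hypothetical helper, not necessarily what `calc_qparams` computes internally.

```python
import torch

def pertensor_asym_qparams(x: torch.Tensor, num_bits: int = 8):
    """Illustrative asymmetric per-tensor quantization parameters.

    Textbook min/max formulas only -- NOT necessarily what calc_qparams
    does internally for Qscheme(unit="perT", symmetric=False, ...).
    """
    qmin, qmax = 0, 2**num_bits - 1                 # full range, no qlevel lowering
    cmin = torch.minimum(x.min(), torch.zeros(()))  # keep 0 exactly representable
    cmax = torch.maximum(x.max(), torch.zeros(()))
    scale = (cmax - cmin) / (qmax - qmin)
    zero_point = int(torch.clamp(torch.round(-cmin / scale), qmin, qmax))
    return scale.item(), zero_point

x = torch.randn(64, 64)
scale, zp = pertensor_asym_qparams(x)
x_dq = torch.fake_quantize_per_tensor_affine(x, scale, zp, 0, 255)
```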
@@ -560,75 +558,75 @@ def calc_qparams( |
560 | 558 | # Placeholder classes for PerCh - need to rework # |
561 | 559 | class SAWBPlusZeroPerChSTE_new(PerChannelSTESAWB): |
562 | 560 | """ |
563 | | - per-channel SAWB with zero alignment, can use 15 or 16 bins, i.e. [-7,7] or [-7,8] |
| 561 | + per-channel SAWB with zero alignment, can use 15 or 16 bins, i.e. [-7,7] or [-7,8] |
564 | 562 | """ |
565 | 563 |
|
566 | | - @staticmethod |
567 | | - def forward( |
568 | | - ctx, |
569 | | - input_tensor: torch.FloatTensor, |
570 | | - num_bits: torch.IntTensor, |
571 | | - _clip_valn: torch.FloatTensor = clip_valn_default, |
572 | | - clip_val: torch.FloatTensor = clip_val_default, |
573 | | - dequantize: bool = True, |
574 | | - _symmetric: bool = False, |
575 | | - _qlevel_lowering: bool = False, |
576 | | - _use_code: bool = False, |
577 | | - ): |
578 | | - """ |
579 | | - Forward function for SAWBPlusZeroPerChSTE |
580 | | -
|
581 | | - Args: |
582 | | - ctx (torch.autograd.Function): Forward/Backward context object. |
583 | | - input_tensor (torch.FloatTensor): Tensor to be quantized. |
584 | | - num_bits (torch.IntTensor): Number of bit for quantization. |
585 | | - clip_valn (torch.FloatTensor): Lower clip value bound. |
586 | | - clip_val (torch.FloatTensor): Upper clip value bound. |
587 | | - dequantize (bool, optional): Return dequantized or int tensor. Defaults to True. |
588 | | - symmetric (bool, optional): Specify if clip values are symmetric. Defaults to False. |
589 | | - qlevel_lowering (bool, optional): Specify lowering of quantized levels. |
590 | | - Defaults to True. |
591 | | - use_code (bool, optional): Specify using SAWB code. Defaults to False. |
592 | | -
|
593 | | - Returns: |
594 | | - torch.Tensor: Dequantized or Quantized output tensor. |
595 | | - """ |
596 | | - # assert num_bits in [4, 8], "only implemented for 4bit and 8bit" |
597 | | - |
598 | | - SAWBcode_mapping = {8: 803, 4: 403, 2: 103} |
599 | | - num_bits_int = ( |
600 | | - num_bits.item() if isinstance(num_bits, torch.Tensor) else num_bits |
601 | | - ) |
602 | | - clip_val, _ = sawb_params_code( |
603 | | - num_bits_int, SAWBcode_mapping[num_bits_int], input_tensor, perCh=True |
604 | | - ) |
605 | | - |
606 | | - _nspace = 2**num_bits - 2 # + objSAWB.use16bins # Ignore 16bins for now |
607 | | - int_l = -(2 ** (num_bits - 1)) + 1 |
608 | | - int_u = -int_l # + objSAWB.use16bins # Ignore 16bins for now |
609 | | - |
610 | | - scale = clip_val * 2 / (2**num_bits - 2) |
611 | | - # original SAWB assumes odd number of bins when calc clip_val |
612 | | - zero_point = torch.zeros_like(scale) # SAWB always centers around 0 and align 0 |
613 | | - |
614 | | - if dequantize: |
615 | | - output = torch.fake_quantize_per_channel_affine( |
616 | | - input_tensor.float(), |
617 | | - scale.float(), |
618 | | - zero_point.float(), |
619 | | - axis=0, |
620 | | - quant_min=int_l, |
621 | | - quant_max=int_u, |
622 | | - ).to( |
623 | | - clip_val.dtype |
624 | | - ) # NOTE return will be a fp32 tensor; function only support float() |
625 | | - else: |
626 | | - output = torch.quantize_per_channel( |
627 | | - input_tensor, scale, zero_point, 0, torch.qint8 |
628 | | - ).int_repr() |
629 | | - # NOTE return will be a torch.int8 tensor |
630 | | - |
631 | | - return output |
| 564 | + # @staticmethod |
| 565 | + # def forward( |
| 566 | + # ctx, |
| 567 | + # input_tensor: torch.FloatTensor, |
| 568 | + # num_bits: torch.IntTensor, |
| 569 | + # _clip_valn: torch.FloatTensor = clip_valn_default, |
| 570 | + # clip_val: torch.FloatTensor = clip_val_default, |
| 571 | + # dequantize: bool = True, |
| 572 | + # _symmetric: bool = False, |
| 573 | + # _qlevel_lowering: bool = False, |
| 574 | + # _use_code: bool = False, |
| 575 | + # ): |
| 576 | + # """ |
| 577 | + # Forward function for SAWBPlusZeroPerChSTE |
| 578 | + |
| 579 | + # Args: |
| 580 | + # ctx (torch.autograd.Function): Forward/Backward context object. |
| 581 | + # input_tensor (torch.FloatTensor): Tensor to be quantized. |
| 582 | + # num_bits (torch.IntTensor): Number of bit for quantization. |
| 583 | + # clip_valn (torch.FloatTensor): Lower clip value bound. |
| 584 | + # clip_val (torch.FloatTensor): Upper clip value bound. |
| 585 | + # dequantize (bool, optional): Return dequantized or int tensor. Defaults to True. |
| 586 | + # symmetric (bool, optional): Specify if clip values are symmetric. Defaults to False. |
| 587 | + # qlevel_lowering (bool, optional): Specify lowering of quantized levels. |
| 588 | + # Defaults to True. |
| 589 | + # use_code (bool, optional): Specify using SAWB code. Defaults to False. |
| 590 | + |
| 591 | + # Returns: |
| 592 | + # torch.Tensor: Dequantized or Quantized output tensor. |
| 593 | + # """ |
| 594 | + # # assert num_bits in [4, 8], "only implemented for 4bit and 8bit" |
| 595 | + |
| 596 | + # SAWBcode_mapping = {8: 803, 4: 403, 2: 103} |
| 597 | + # num_bits_int = ( |
| 598 | + # num_bits.item() if isinstance(num_bits, torch.Tensor) else num_bits |
| 599 | + # ) |
| 600 | + # clip_val, _ = sawb_params_code( |
| 601 | + # num_bits_int, SAWBcode_mapping[num_bits_int], input_tensor, perCh=True |
| 602 | + # ) |
| 603 | + |
| 604 | + # _nspace = 2**num_bits - 2 # + objSAWB.use16bins # Ignore 16bins for now |
| 605 | + # int_l = -(2 ** (num_bits - 1)) + 1 |
| 606 | + # int_u = -int_l # + objSAWB.use16bins # Ignore 16bins for now |
| 607 | + |
| 608 | + # scale = clip_val * 2 / (2**num_bits - 2) |
| 609 | + # # original SAWB assumes odd number of bins when calc clip_val |
| 610 | + # zero_point = torch.zeros_like(scale) # SAWB always centers around 0 and align 0 |
| 611 | + |
| 612 | + # if dequantize: |
| 613 | + # output = torch.fake_quantize_per_channel_affine( |
| 614 | + # input_tensor.float(), |
| 615 | + # scale.float(), |
| 616 | + # zero_point.float(), |
| 617 | + # axis=0, |
| 618 | + # quant_min=int_l, |
| 619 | + # quant_max=int_u, |
| 620 | + # ).to( |
| 621 | + # clip_val.dtype |
| 622 | + # ) # NOTE return will be a fp32 tensor; function only support float() |
| 623 | + # else: |
| 624 | + # output = torch.quantize_per_channel( |
| 625 | + # input_tensor, scale, zero_point, 0, torch.qint8 |
| 626 | + # ).int_repr() |
| 627 | + # # NOTE return will be a torch.int8 tensor |
| 628 | + |
| 629 | + # return output |
632 | 630 |
|
633 | 631 | @staticmethod |
634 | 632 | def backward(ctx, grad_output): |
|
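While `SAWBPlusZeroPerChSTE_new.forward` stays parked as a comment, the zero-aligned per-channel fake-quant step it performed can be restated as a standalone sketch. The per-channel `clip_val` below is a plain absolute max, used only as a stand-in for the value `sawb_params_code` would return, and `sawb_perch_fakequant` is a hypothetical helper, not part of the library.

```python
import torch

def sawb_perch_fakequant(w: torch.Tensor, clip_val: torch.Tensor, num_bits: int = 4):
    """Zero-aligned per-channel fake quantization (15-bin flavour, e.g. [-7, 7] for 4 bits).

    Restates the math of the commented-out SAWBPlusZeroPerChSTE_new.forward:
    scale = 2 * clip_val / (2**num_bits - 2), zero_point = 0,
    integer range [-(2**(num_bits - 1)) + 1, 2**(num_bits - 1) - 1].
    """
    int_l = -(2 ** (num_bits - 1)) + 1      # -7 for 4 bits
    int_u = -int_l                          # +7 -> 15 bins; the 16-bin variant would use +8
    scale = clip_val * 2 / (2**num_bits - 2)
    zero_point = torch.zeros_like(scale)    # SAWB centers on 0 and aligns 0 exactly
    return torch.fake_quantize_per_channel_affine(
        w.float(), scale.float(), zero_point.float(), axis=0,
        quant_min=int_l, quant_max=int_u,
    ).to(w.dtype)

w = torch.randn(16, 32)                     # [out_channels, in_features]
clip_val = w.abs().amax(dim=1)              # stand-in for what sawb_params_code would return
w_dq = sawb_perch_fakequant(w, clip_val)
```

Passing float tensors for `scale` and `zero_point` mirrors the commented-out code; the `[-7, 7]` range is the 15-bin flavour mentioned in the docstring, with the 16-bin variant extending the upper bound to +8.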