24 | 24 |
25 | 25 | # Local |
26 | 26 | from fms_mo.quant_refactor.base_quant import Qscheme, Quantizer |
| 27 | +from fms_mo.quant_refactor.per_channel_ste import ( |
| 28 | + PerChannelSTE_PTnative, |
| 29 | + PerChannelSTEQmax, |
| 30 | + PerChannelSTEQmax_PTnative, |
| 31 | +) |
27 | 32 | from fms_mo.quant_refactor.per_tensor_ste import ( |
28 | 33 | PerTensorSTE_PTnative, |
29 | 34 | PerTensorSTEQmax, |
@@ -123,13 +128,29 @@ def set_quantizer(self): |
123 | 128 | self.copy_legacy_vars() # copy align_zero |
124 | 129 |
125 | 130 | if self.use_PT_native_Qfunc: |
126 | | - if self.extend_act_range: |
127 | | - self.quantizer_name = "QmaxExtend" |
128 | | - self.set_clip_ratio() |
129 | | - self.quantizer = QmaxExtendRangeSTE_PTnative |
| 131 | + if self.minmax: |
| 132 | + self.quantizer_name = "Qminmax" |
| 133 | + if self.perGrp: |
| 134 | + # self.quantizer = QmaxPerGrpSTE_PTnative |
| 135 | + pass |
| 136 | + else: |
| 137 | + self.quantizer = ( |
| 138 | + QmaxPerChSTE_PTnative |
| 139 | + if self.perCh |
| 140 | + else PerTensorSTEQmax_PTnative |
| 141 | + ) |
130 | 142 | else: |
131 | | - self.quantizer_name = "Qminmax" if self.minmax else "Qmax" |
132 | | - self.quantizer = PerTensorSTEQmax_PTnative |
| 143 | + if self.extend_act_range: |
| 144 | + self.quantizer_name = "QmaxExtend" |
| 145 | + self.set_clip_ratio() |
| 146 | + self.quantizer = QmaxExtendRangeSTE_PTnative |
| 147 | + else: |
| 148 | + self.quantizer_name = "Qminmax" if self.minmax else "Qmax" |
| 149 | + self.quantizer = ( |
| 150 | + QmaxPerChSTE_PTnative |
| 151 | + if self.perCh |
| 152 | + else PerTensorSTEQmax_PTnative |
| 153 | + ) |
133 | 154 | else: |
134 | 155 | if self.minmax: |
135 | 156 | self.quantizer_name = "Qminmax" |
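(Not part of the diff.) As a reading aid, the PT-native branch added above can be summarized by the sketch below; the flag and class names are taken from the code, but the helper itself is purely illustrative and returns class names as strings.

def ptnative_choice(minmax, perCh, perGrp, extend_act_range):
    if minmax:
        if perGrp:
            return "Qminmax", None  # per-group PT-native STE not wired up yet (the diff leaves `pass`)
        return "Qminmax", "QmaxPerChSTE_PTnative" if perCh else "PerTensorSTEQmax_PTnative"
    if extend_act_range:
        return "QmaxExtend", "QmaxExtendRangeSTE_PTnative"
    # minmax is always False on this path, so the name reduces to "Qmax"
    return "Qmax", "QmaxPerChSTE_PTnative" if perCh else "PerTensorSTEQmax_PTnative"

assert ptnative_choice(minmax=True, perCh=True, perGrp=False, extend_act_range=False) == ("Qminmax", "QmaxPerChSTE_PTnative")
assert ptnative_choice(minmax=False, perCh=False, perGrp=False, extend_act_range=True) == ("QmaxExtend", "QmaxExtendRangeSTE_PTnative")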
@@ -171,14 +192,14 @@ def forward(self, input_tensor: torch.FloatTensor) -> torch.Tensor: |
171 | 192 | if self.perCh: |
172 | 193 | if self.minmax: |
173 | 194 | clip_val_new = torch.max( |
174 | | - input_tensor.reshape([self.perCh, -1]), dim=1 |
| 195 | + input_tensor.reshape([self.qscheme.Nch, -1]), dim=1 |
175 | 196 | ).values |
176 | 197 | clip_valn_new = torch.min( |
177 | | - input_tensor.reshape([self.perCh, -1]), dim=1 |
| 198 | + input_tensor.reshape([self.qscheme.Nch, -1]), dim=1 |
178 | 199 | ).values |
179 | 200 | else: |
180 | 201 | clip_val_new = torch.max( |
181 | | - input_tensor.abs().reshape([self.perCh, -1]), dim=1 |
| 202 | + input_tensor.abs().reshape([self.qscheme.Nch, -1]), dim=1 |
182 | 203 | ).values |
183 | 204 | clip_valn_new = -clip_val_new |
184 | 205 | assert ( |
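(Not part of the diff.) The reshape([self.qscheme.Nch, -1]) change above derives one clip value per output channel from the channel count in the qscheme, instead of reusing the perCh attribute as a channel count. A minimal standalone sketch of the same computation, with Nch standing in for self.qscheme.Nch:

import torch

w = torch.tensor([[0.5, -2.0, 1.0],
                  [3.0, -0.5, 0.25]])
Nch = w.shape[0]  # stands in for self.qscheme.Nch
cv = torch.max(w.reshape([Nch, -1]), dim=1).values       # minmax: tensor([1.00, 3.00])
cvn = torch.min(w.reshape([Nch, -1]), dim=1).values      #         tensor([-2.00, -0.50])
cv_sym = torch.max(w.abs().reshape([Nch, -1]), dim=1).values  # symmetric: tensor([2.00, 3.00])
cvn_sym = -cv_sym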
@@ -487,50 +508,32 @@ def calc_qparams( |
487 | 508 | return n_levels, clip_val, scale, zero_point, qint_min, qint_max, qint_dtype |
488 | 509 |
489 | 510 |
490 | | -# Placeholder classes for PerCh/PerGp - need to rework # |
491 | | -class QmaxPerChSTE_new(torch.autograd.Function): |
| 511 | +class QmaxPerChSTE_new(PerChannelSTEQmax): |
492 | 512 | """ |
493 | 513 | Max with zero alignment (symmetric) |
494 | 514 | "dequantize=False" option is functional |
495 | 515 | """ |
496 | 516 |
497 | 517 | @staticmethod |
498 | | - def forward( |
499 | | - ctx, |
500 | | - input_tensor, |
501 | | - num_bits, |
502 | | - _dequantize, |
503 | | - inplace, |
504 | | - _cvn, |
505 | | - cv, |
506 | | - align_zero, |
507 | | - ) -> torch.FloatTensor: |
| 518 | + def backward(ctx, grad_output): |
508 | 519 | """ |
509 | | - TODO (bmgroth): docstring |
| 520 | + Backward function for Qmax Per Channel STE |
| 521 | +
| 522 | + Args: |
| 523 | + ctx (torch.autograd.Function): Context object. |
| 524 | + grad_output (torch.FloatTensor): Gradient to clip |
| 525 | +
| 526 | + Returns: |
| 527 | + [torch.FloatTensor, None,...,None]: Gradients |
510 | 528 | """ |
511 | | - if inplace: |
512 | | - ctx.mark_dirty(input) |
513 | | - scale = (2**num_bits - 2) if align_zero else (2**num_bits - 1) |
514 | | - zero_point = 0.0 |
515 | | - _clip_val = cv |
516 | | - # here use symmetric similar to sawbperCh code |
517 | | - _nspace = 2**num_bits - 2 # lose one level |
518 | | - int_l = -(2 ** (num_bits - 1)) + 1 |
519 | | - int_u = -int_l # symmetric |
520 | | - scale = ( |
521 | | - cv * 2 / (2**num_bits - 2) |
522 | | - ) # original SAWB assumes odd number of bins when calc clip_val |
523 | | - zero_point = torch.zeros_like(scale) # centers around 0 and align 0 |
524 | | - # FIXME, fake quantize function only support float. |
525 | | - output = torch.fake_quantize_per_channel_affine( |
526 | | - input_tensor.float(), |
527 | | - scale.float(), |
528 | | - zero_point.float(), |
529 | | - axis=0, |
530 | | - quant_min=int_l, |
531 | | - quant_max=int_u, |
532 | | - ).to(input_tensor.dtype) |
533 | | - return output |
| 529 | + return grad_output, None, None, None, None, None, None |
| 530 | + |
| 531 | + |
| 532 | +class QmaxPerChSTE_PTnative(PerChannelSTEQmax_PTnative): |
| 533 | + """ |
| 534 | + Max with zero alignment (symmetric) |
| 535 | + "dequantize=False" option is functional |
| 536 | + """ |
534 | 537 |
535 | 538 | @staticmethod |
536 | 539 | def backward(ctx, grad_output): |
@@ -606,31 +609,11 @@ def backward(ctx, grad_output): |
606 | 609 | return grad_output, None, None, None, None, None, None |
607 | 610 |
608 | 611 |
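(Not part of the diff.) Two notes on the per-channel Qmax classes above. First, each backward returns grad_output plus six Nones so that the number of outputs matches the seven non-ctx arguments of the removed forward (input_tensor, num_bits, _dequantize, inplace, _cvn, cv, align_zero); only the input receives a gradient, per the straight-through estimator. Second, the symmetric per-channel math that the deleted forward implemented inline (a SAWB-style grid that drops one level so zero is exactly representable) can be sketched standalone as below; the refactored classes presumably inherit their forward from the PerChannelSTE base classes, so this illustrates the math only, not the parent implementation.

import torch

def qmax_per_channel_sketch(w, cv, num_bits=8):
    # symmetric grid with 2**num_bits - 2 steps, centered on zero
    scale = cv * 2 / (2**num_bits - 2)
    zero_point = torch.zeros_like(cv, dtype=torch.int32)
    int_l = -(2 ** (num_bits - 1)) + 1
    int_u = -int_l
    return torch.fake_quantize_per_channel_affine(
        w.float(), scale.float(), zero_point,
        axis=0, quant_min=int_l, quant_max=int_u,
    ).to(w.dtype)

w = torch.randn(4, 16)
cv = w.abs().reshape([4, -1]).max(dim=1).values  # per-channel symmetric clip values
w_q = qmax_per_channel_sketch(w, cv)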
609 | | -class QminmaxPerChSTE_new(torch.autograd.Function): |
| 612 | +class QminmaxPerChSTE_new(PerChannelSTEQmax): |
610 | 613 | """ |
611 | 614 | per channel minmax with zero alignment (asymmetric) |
612 | 615 | """ |
613 | 616 |
614 | | - @staticmethod |
615 | | - def forward( |
616 | | - ctx, input_tensor, num_bits, _dequantize, inplace, cv, cvn, align_zero |
617 | | - ) -> torch.FloatTensor: |
618 | | - """ |
619 | | - TODO (bmgroth): docstring |
620 | | - """ |
621 | | - if inplace: |
622 | | - ctx.mark_dirty(input_tensor) |
623 | | - cv, cvn = cv.to(input_tensor.dtype), cvn.to(input_tensor.dtype) |
624 | | - scale = (2**num_bits - 1) / (cv - cvn) |
625 | | - zero_point = cvn * scale |
626 | | - if align_zero: |
627 | | - zero_point = torch.round(zero_point) |
628 | | - output = (input_tensor.clamp(cvn[:, None], cv[:, None]) - cvn[:, None]) * scale[ |
629 | | - :, None |
630 | | - ] |
631 | | - output = (torch.round(output) + zero_point[:, None]) / scale[:, None] |
632 | | - return output |
633 | | - |
634 | 617 | @staticmethod |
635 | 618 | def backward(ctx, grad_output): |
636 | 619 | """ |
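(Not part of the diff.) Likewise, the asymmetric per-channel quantize/dequantize that QminmaxPerChSTE_new.forward used to perform inline, and now presumably delegates to its parent class, reduces to the sketch below; only the math from the deleted lines is reproduced here.

import torch

def qminmax_per_channel_sketch(w, cv, cvn, num_bits=8, align_zero=True):
    scale = (2**num_bits - 1) / (cv - cvn)  # per-channel scale
    zero_point = cvn * scale
    if align_zero:
        zero_point = torch.round(zero_point)
    # clamp to [cvn, cv] per channel, quantize, then dequantize
    q = torch.round((w.clamp(cvn[:, None], cv[:, None]) - cvn[:, None]) * scale[:, None])
    return (q + zero_point[:, None]) / scale[:, None]

w = torch.randn(4, 16)
cv = w.reshape([4, -1]).max(dim=1).values
cvn = w.reshape([4, -1]).min(dim=1).values
w_dq = qminmax_per_channel_sketch(w, cv, cvn)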