Functions to create quantizers for activation and weights. Called from Qmodule level.
"""
1818
19+ # Third Party
1920import torch
2021
2122# Local
3839
3940
4041def get_activation_quantizer_new (
41- qa_mode :str = "PACT" ,
42- nbits :int = 32 ,
43- clip_val :torch .FloatTensor = None ,
44- clip_valn :torch .FloatTensor = None ,
45- non_neg :bool = False ,
46- align_zero :bool = True , # pylint: disable=unused-argument
47- extend_act_range :bool = False ,
48- use_PT_native_Qfunc :bool = False ,
49- use_subnormal :bool = False ,
42+ qa_mode : str = "PACT" ,
43+ nbits : int = 32 ,
44+ clip_val : torch .FloatTensor = None ,
45+ clip_valn : torch .FloatTensor = None ,
46+ non_neg : bool = False ,
47+ align_zero : bool = True , # pylint: disable=unused-argument
48+ extend_act_range : bool = False ,
49+ use_PT_native_Qfunc : bool = False ,
50+ use_subnormal : bool = False ,
5051):
5152 """Return a quantizer for activation quantization
5253 Regular quantizers:
@@ -212,16 +213,16 @@ def get_activation_quantizer_new(
212213
213214
214215def get_weight_quantizer_new (
215- qw_mode :str = "SAWB+" ,
216- nbits :int = 32 ,
217- clip_val :torch .FloatTensor = None ,
218- clip_valn :torch .FloatTensor = None ,
219- align_zero :bool = True ,
220- w_shape :torch .Size = None ,
221- recompute :bool = False , # pylint: disable=unused-argument
222- perGp :int = None ,
223- use_PT_native_Qfunc :bool = False ,
224- use_subnormal :bool = False ,
216+ qw_mode : str = "SAWB+" ,
217+ nbits : int = 32 ,
218+ clip_val : torch .FloatTensor = None ,
219+ clip_valn : torch .FloatTensor = None ,
220+ align_zero : bool = True ,
221+ w_shape : torch .Size = None ,
222+ recompute : bool = False , # pylint: disable=unused-argument
223+ perGp : int = None ,
224+ use_PT_native_Qfunc : bool = False ,
225+ use_subnormal : bool = False ,
225226):
226227 """Return a quantizer for weight quantization
227228 Regular quantizers:
@@ -236,13 +237,7 @@ def get_weight_quantizer_new(
236237 Ngrp = (
237238 [w_shape [0 ] * w_shape [1 ] // perGp , perGp ] if "perGp" in qw_mode else False
238239 ) # store clip_val size and group size
239- unit = (
240- "perCh"
241- if Nch is not False
242- else "perGrp"
243- if perGp is not None
244- else "perT"
245- )
240+ unit = "perCh" if Nch is not False else "perGrp" if perGp is not None else "perT"
246241 if "sawb" in qw_mode :
247242 clipSTE = "+" in qw_mode
248243 weight_quantizer = SAWB_new (
@@ -260,7 +255,6 @@ def get_weight_quantizer_new(
260255 use_PT_native_Qfunc = use_PT_native_Qfunc ,
261256 )
262257 elif "max" in qw_mode :
263-
264258 weight_quantizer = Qmax_new (
265259 nbits ,
266260 Qscheme = Qscheme (
0 commit comments