@@ -73,8 +73,14 @@ def __init__(
             num_bits.item() if isinstance(num_bits, torch.Tensor) else num_bits
         )
         # turn clips into tensors (from python float)
-        self.clip_low = torch.Tensor([clip_low])
-        self.clip_high = torch.Tensor([clip_high])
+        self.clip_low = (
+            torch.Tensor([clip_low]) if not isinstance(clip_low, torch.Tensor)
+            else clip_low
+        )
+        self.clip_high = (
+            torch.Tensor([clip_high]) if not isinstance(clip_high, torch.Tensor)
+            else clip_high
+        )
         self.symmetric_zp0 = False
         self.qscheme = qscheme
         self.set_quant_bounds()
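The new guard makes the clip conversion idempotent, so per-channel clip vectors pass through untouched while plain Python floats are still wrapped as before. A minimal sketch of the pattern (the standalone helper name is hypothetical; in the diff the expression lives inline in `__init__`):

    import torch

    def to_clip_tensor(clip):
        # Wrap plain Python floats; pass tensors (e.g. per-channel
        # clip vectors) through unchanged.
        return torch.Tensor([clip]) if not isinstance(clip, torch.Tensor) else clip

    print(to_clip_tensor(-1.0))                        # tensor([-1.])
    print(to_clip_tensor(torch.tensor([-1.0, -0.5])))  # tensor([-1.0000, -0.5000])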
@@ -102,11 +108,11 @@ def get_setup(self):
         """
         return (
             self.num_bits_int,
-            self.clip_low.item(),
-            self.clip_high.item(),
+            self.clip_low,
+            self.clip_high,
             self.n_levels.item(),
-            self.scale.item(),
-            self.zero_point.item(),
+            self.scale,
+            self.zero_point,
             self.quant_min,
             self.quant_max,
             self.qscheme,
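`get_setup` stops calling `.item()` because a per-channel quantizer carries one clip/scale/zero-point value per channel, and `.item()` only works on single-element tensors. A quick illustration with an assumed 3-channel shape:

    import torch

    scale = torch.tensor([0.1, 0.2, 0.05])  # hypothetical per-channel scales
    # scale.item() would raise a RuntimeError ("a Tensor with 3 elements cannot
    # be converted to Scalar"), so the setup tuple now returns the tensors.
    print(scale.numel() == 1)  # False: .item() is no longer safe here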
@@ -127,7 +133,7 @@ def set_quant_bounds(self):
         self.scale = (self.clip_high - self.clip_low) / (self.n_levels)
         # this "ZP" will map the float value we choose (clip_low in this case) to the 0 bin
         self.zero_point = (
-            torch.tensor(0)
+            torch.zeros(self.scale.shape, dtype=torch.int)
             if (self.is_symmetric)
             else torch.round(-self.clip_low / self.scale).to(torch.int)
         )
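The symmetric zero point is now a zero tensor shaped like `scale`, so it broadcasts per channel. A worked example of the asymmetric branch under assumed values (8 bits, taking `n_levels = 2**8 - 1`), following the formulas in the hunk:

    import torch

    clip_low, clip_high = torch.tensor([-0.5]), torch.tensor([1.5])
    n_levels = 2**8 - 1                        # assumed: 255 bins for 8 bits
    scale = (clip_high - clip_low) / n_levels  # tensor([0.0078])
    zero_point = torch.round(-clip_low / scale).to(torch.int)
    print(zero_point)  # tensor([64]): round(clip_low/scale) + 64 lands in bin 0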
@@ -138,7 +144,7 @@ def set_quant_range(self):
         """
         Set quantization integer range based on member variables
         """
-        if self.is_symmetric and self.zero_point == 0:
+        if self.is_symmetric and torch.sum(self.zero_point) == 0:
             # Either [-8,7];[-128,127] for non-symmetric or [-7,7];[-127,127] for qlevel_lowering
             self.quant_min, self.quant_max = (
                 -(2 ** (self.num_bits - 1)) + self.symmetric_nlevel,
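Comparing a multi-element zero point with `== 0` yields a boolean tensor, which Python cannot use directly as an `if` condition; summing first reduces it to a scalar test (and, for the non-negative zero points this class produces, a zero sum implies all entries are zero). For instance:

    import torch

    zp = torch.zeros(3, dtype=torch.int)  # per-channel zero points
    # bool(zp == 0) raises "Boolean value of Tensor with more than one element
    # is ambiguous"; the reduction below works for scalars and vectors alike.
    print(torch.sum(zp) == 0)             # tensor(True)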
@@ -264,7 +270,7 @@ def get_torch_dtype(self):
         if self.is_single_sided:
             signed = False
         else:
-            signed = (self.zero_point == 0).item()
+            signed = (torch.sum(self.zero_point) == 0).item()
         return self.dtype_dict.get(
             (self.num_bits_int, signed)
         )  # NOTE .item() won't work for perCh
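`get_torch_dtype` applies the same reduction before the signedness lookup. The contents of `self.dtype_dict` are not shown in the diff; a plausible stand-in for the mapping it implies:

    import torch

    # Hypothetical stand-in for self.dtype_dict: (num_bits, signed) -> dtype
    dtype_dict = {(8, True): torch.qint8, (8, False): torch.quint8}
    zero_point = torch.zeros(3, dtype=torch.int)
    signed = (torch.sum(zero_point) == 0).item()  # a Python bool, even per-channel
    print(dtype_dict.get((8, signed)))            # torch.qint8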
@@ -282,24 +288,53 @@ def forward(self, tensor: torch.FloatTensor):
         Returns:
             torch.FloatTensor: Quantized or dequantized tensor.
         """
+
         if self.dequantize:
-            output = torch.fake_quantize_per_tensor_affine(
-                tensor, self.scale, self.zero_point, self.quant_min, self.quant_max
-            )
+            if self.qscheme.Nch:  # Per Channel
+                output = torch.fake_quantize_per_channel_affine(
+                    tensor,
+                    self.scale.float(),
+                    self.zero_point.float(),
+                    self.qscheme.axis,
+                    self.quant_min,
+                    self.quant_max,
+                )
+            elif self.qscheme.Ngrp:  # Per Group
+                raise RuntimeError("TorchQuantizer forward not implemented for perGrp")
+            else:  # Per Tensor
+                output = torch.fake_quantize_per_tensor_affine(
+                    tensor,
+                    self.scale,
+                    self.zero_point,
+                    self.quant_min,
+                    self.quant_max,
+                )
         else:
             dtype = self.get_torch_dtype()
             if dtype:
-                # Clamp to [quant_min, quant_max] in case we are storing int4 into int8/uint8 tensor
-                output = (
-                    torch.quantize_per_tensor(
-                        tensor, self.scale, self.zero_point, dtype
+                if self.qscheme.q_unit == "perCh":
+                    output = torch.quantize_per_channel(
+                        tensor,
+                        self.scale,
+                        self.zero_point,
+                        self.qscheme.axis,
+                        dtype,
                     )
-                    .int_repr()
-                    .clamp(self.quant_min, self.quant_max)
-                )
+                elif self.qscheme.q_unit == "perGrp":
+                    raise RuntimeError("TorchQuantizer forward not implemented for perGrp")
+                else:  # Per Tensor
+                    output = torch.quantize_per_tensor(
+                        tensor,
+                        self.scale,
+                        self.zero_point,
+                        dtype,
+                    )
+                # Clamp required if storing int4 into int8 tensor (no PT support for int4)
+                output = output.int_repr().clamp(self.quant_min, self.quant_max)
             else:
                 raise RuntimeError(
                     f"num_bits {self.num_bits} and sign {(self.zero_point == 0).item()} "
                     "combination results in unavailable dtype."
                 )
+
         return output
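A self-contained sketch of the new per-channel dequantize path, using the same `torch.fake_quantize_per_channel_affine` call with assumed shapes (the axis and quant range are illustrative; the diff reads them from `self.qscheme` and the quant bounds):

    import torch

    x = torch.randn(4, 3)
    scale = torch.tensor([0.10, 0.20, 0.05])      # one scale per channel on axis 1
    zero_point = torch.zeros(3, dtype=torch.int)
    out = torch.fake_quantize_per_channel_affine(
        x, scale.float(), zero_point.float(), axis=1, quant_min=-127, quant_max=127
    )
    print(out.shape)  # torch.Size([4, 3]), values snapped to each channel's grid

The quantize path mirrors this with `torch.quantize_per_channel`, then clamps the `int_repr()` because PyTorch has no native int4 storage.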