@@ -3110,14 +3110,29 @@ def forward(
         )  # original SAWB assumes odd number of bins when calc clip_val
         zero_point = torch.zeros_like(scale)  # centers around 0 and aligns 0
         # FIXME: fake quantize function only supports float.
-        output = torch.fake_quantize_per_channel_affine(
-            input.float(),
-            scale.float(),
-            zero_point.float(),
-            axis=0,
-            quant_min=int_l,
-            quant_max=int_u,
-        ).to(input.dtype)
+
+        if dequantize:
+            output = torch.fake_quantize_per_channel_affine(
+                input.float(),
+                scale.float(),
+                zero_point.float(),
+                axis=0,
+                quant_min=int_l,
+                quant_max=int_u,
+            ).to(input.dtype)
+        else:
+            output = (
+                torch.quantize_per_channel(
+                    input.float(),
+                    scale.float(),
+                    zero_point.float(),
+                    axis=0,
+                    dtype=torch.qint8,
+                )
+                .int_repr()
+                .clamp(int_l, int_u)
+            )
+
         return output

     @staticmethod
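The new `else` branch above swaps the float-only fake quantize for a real `torch.quantize_per_channel` followed by `int_repr()`, so callers can get raw integer codes instead of dequantized floats. A minimal sketch of why the two branches agree, assuming symmetric int4-style bounds (`int_l = -8`, `int_u = 7`) and toy per-channel scales (both assumptions, not taken from the commit):

```python
import torch

x = torch.randn(4, 16)                # quantize per channel along axis 0
scale = x.abs().amax(dim=1) / 7.0     # toy per-channel scales (assumption)
zero_point = torch.zeros_like(scale)  # symmetric: zero maps exactly to 0
int_l, int_u = -8, 7                  # int4-style bounds (assumption)

# Branch taken when dequantize=True: quantize then dequantize, staying in float.
dq = torch.fake_quantize_per_channel_affine(
    x.float(), scale.float(), zero_point.float(),
    axis=0, quant_min=int_l, quant_max=int_u,
)

# Branch taken when dequantize=False: keep the raw integer codes.
q = (
    torch.quantize_per_channel(
        x.float(), scale.float(), zero_point.long(),  # .long() here: some torch
        axis=0, dtype=torch.qint8,                    # versions reject float zps
    )
    .int_repr()
    .clamp(int_l, int_u)
)

# Rescaling the integer codes reproduces the fake-quantized floats.
assert torch.allclose(dq, q.float() * scale[:, None], atol=1e-5)
```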
@@ -3210,15 +3225,14 @@ def forward(
         ctx.mark_dirty(input)
         clip_val, clip_valn = clip_val.to(input.dtype), clip_valn.to(input.dtype)
         scale = (clip_val - clip_valn) / (2**num_bits - 1)
-        zero_point = torch.round(-clip_valn / scale).to(torch.int)
+        zero_point = torch.round(clip_valn / scale).to(torch.int)

-        output = input.clamp(clip_valn[:, None], clip_val[:, None])
-        output = torch.round(output / scale[:, None] - zero_point[:, None])
+        output = torch.round(input / scale[:, None] - zero_point[:, None])
         if dequantize:
             output = (output + zero_point[:, None]) * scale[:, None]
         else:
-            n_half = 2 ** (num_bits - 1)
-            output = (output - n_half).to(torch.int8)
+            output = output.to(torch.uint8)
+
         return output

     @staticmethod
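In the second hunk, flipping the sign makes `zero_point = round(clip_valn / scale)` a negative integer whenever `clip_valn < 0`, so `round(input / scale - zero_point)` lands directly in the unsigned range `[0, 2**num_bits - 1]` and can be stored as `uint8` without the old `n_half` shift into `int8`. A quick round-trip sketch under assumed toy clip values (per-channel min/max of the input, not taken from the commit):

```python
import torch

num_bits = 8
x = torch.randn(4, 16)
clip_val = x.amax(dim=1)   # toy per-channel upper clips (assumption)
clip_valn = x.amin(dim=1)  # toy per-channel lower clips (assumption)

scale = (clip_val - clip_valn) / (2**num_bits - 1)
zero_point = torch.round(clip_valn / scale).to(torch.int)  # negative int

# dequantize=False path: integer codes land in [0, 255], storable as uint8.
codes = torch.round(x / scale[:, None] - zero_point[:, None])
assert codes.min() >= 0 and codes.max() <= 2**num_bits - 1
stored = codes.to(torch.uint8)

# dequantize=True path: undo the shift and rescale; the round-trip error
# stays within half a quantization step per channel.
dequant = (codes + zero_point[:, None]) * scale[:, None]
assert ((dequant - x).abs() <= scale[:, None] / 2 + 1e-6).all()
```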