Skip to content

Commit 454391f

Browse files
committed
feat: Implemented dequant for Qmax original perCh STEs
Signed-off-by: Brandon Groth <[email protected]>
1 parent 63396cf commit 454391f

File tree

1 file changed

+27
-13
lines changed

1 file changed

+27
-13
lines changed

fms_mo/quant_refactor/quantizers_new.py

Lines changed: 27 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -3110,14 +3110,29 @@ def forward(
31103110
) # original SAWB assumes odd number of bins when calc clip_val
31113111
zero_point = torch.zeros_like(scale) # centers around 0 and align 0
31123112
# FIXME, fake quantize function only support float.
3113-
output = torch.fake_quantize_per_channel_affine(
3114-
input.float(),
3115-
scale.float(),
3116-
zero_point.float(),
3117-
axis=0,
3118-
quant_min=int_l,
3119-
quant_max=int_u,
3120-
).to(input.dtype)
3113+
3114+
if dequantize:
3115+
output = torch.fake_quantize_per_channel_affine(
3116+
input.float(),
3117+
scale.float(),
3118+
zero_point.float(),
3119+
axis=0,
3120+
quant_min=int_l,
3121+
quant_max=int_u,
3122+
).to(input.dtype)
3123+
else:
3124+
output = (
3125+
torch.quantize_per_channel(
3126+
input.float(),
3127+
scale.float(),
3128+
zero_point.float(),
3129+
axis=0,
3130+
dtype=torch.qint8,
3131+
)
3132+
.int_repr()
3133+
.clamp(int_l, int_u)
3134+
)
3135+
31213136
return output
31223137

31233138
@staticmethod
@@ -3210,15 +3225,14 @@ def forward(
32103225
ctx.mark_dirty(input)
32113226
clip_val, clip_valn = clip_val.to(input.dtype), clip_valn.to(input.dtype)
32123227
scale = (clip_val - clip_valn) / (2**num_bits - 1)
3213-
zero_point = torch.round(-clip_valn / scale).to(torch.int)
3228+
zero_point = torch.round(clip_valn / scale).to(torch.int)
32143229

3215-
output = input.clamp(clip_valn[:, None], clip_val[:, None])
3216-
output = torch.round(output / scale[:, None] - zero_point[:, None])
3230+
output = torch.round(input / scale[:, None] - zero_point[:, None])
32173231
if dequantize:
32183232
output = (output + zero_point[:, None]) * scale[:, None]
32193233
else:
3220-
n_half = 2 ** (num_bits - 1)
3221-
output = (output - n_half).to(torch.int8)
3234+
output = output.to(torch.uint8)
3235+
32223236
return output
32233237

32243238
@staticmethod

0 commit comments

Comments (0)