
Commit 33a899c

fix: num_bits_int fix for SAWBPlusZeroPerChSTE
Signed-off-by: Brandon Groth <[email protected]>
1 parent bfadd63

File tree

1 file changed (+3, -2 lines)

fms_mo/quant/quantizers.py

Lines changed: 3 additions & 2 deletions
@@ -510,9 +510,10 @@ def forward(
 
         if istraining:
             # only recalc clipvals under training mode
+            num_bits_int = num_bits.item() if isinstance(num_bits, torch.Tensor) else num_bits
             SAWBcode_mapping = {8: 803, 4: 403, 2: 103}
             if num_bits in [2, 4, 8]:
-                sawb_code = SAWBcode_mapping[num_bits]
+                sawb_code = SAWBcode_mapping[num_bits_int]
                 clip_val, _ = sawb_params_code(
                     num_bits, sawb_code, input_tensor, perCh=True
                 )
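
The added num_bits_int matters because a 0-dim torch.Tensor passes the `in [2, 4, 8]` membership test (which compares by value via __eq__) but fails the dict lookup (tensor keys hash by identity, never matching the int keys). A minimal sketch of the failure mode, assuming only that num_bits can arrive as a tensor:

    import torch

    SAWBcode_mapping = {8: 803, 4: 403, 2: 103}
    num_bits = torch.tensor(4)  # num_bits arriving as a 0-dim tensor

    # Membership uses __eq__, so the guard passes ...
    assert num_bits in [2, 4, 8]

    # ... but dict lookup hashes the tensor by identity, not value:
    try:
        SAWBcode_mapping[num_bits]
    except KeyError:
        print("KeyError: a tensor key never matches the int key 4")

    # .item() unwraps the tensor to a plain Python int, restoring the lookup
    num_bits_int = num_bits.item() if isinstance(num_bits, torch.Tensor) else num_bits
    print(SAWBcode_mapping[num_bits_int])  # 403
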
@@ -551,7 +552,7 @@ def forward(
         else:
             output = torch.quantize_per_channel(
                 input_tensor, scale, zero_point, 0, torch.qint8
-            ).int_repr()
+            ).int_repr().clamp(int_l, int_u)
             # NOTE return will be a torch.int8 tensor
 
         return output
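
The second hunk clamps the raw int8 representation to [int_l, int_u]: int_repr() can emit the full torch.qint8 range, including -128, while a symmetric b-bit quantizer only intends values in ±(2^(b-1) − 1). A short sketch of the pattern, with int_l / int_u assumed here to be the symmetric bounds computed earlier in forward() (hypothetical values, not shown in this diff):

    import torch

    num_bits = 8
    int_l = -(2 ** (num_bits - 1)) + 1  # -127 (assumed symmetric lower bound)
    int_u = 2 ** (num_bits - 1) - 1     #  127 (assumed symmetric upper bound)

    x = torch.randn(4, 16)
    scale = x.abs().amax(dim=1) / int_u             # per-channel scales along dim 0
    zero_point = torch.zeros(4, dtype=torch.int64)

    q = torch.quantize_per_channel(x, scale, zero_point, 0, torch.qint8)
    output = q.int_repr().clamp(int_l, int_u)       # guard against out-of-range values such as -128
    assert output.min() >= int_l and output.max() <= int_u
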
