Skip to content

Commit b1edea7

Browse files
committed
fix observer zp dtype
1 parent d25e435 commit b1edea7

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

mqbench/observer.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,28 @@ def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
5454
scale = pot_quantization(scale)
5555
return scale, zero_point
5656

57+
@torch.jit.export
def _calculate_qparams(
    self, min_val: torch.Tensor, max_val: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Derive quantization parameters from observed min/max tensors.

    The computation itself is delegated to the parent class; this
    override only normalizes the integer dtype of the returned zero
    point. Intended for both the per-tensor and the per-channel case.

    Args:
        min_val: Minimum values (one entry per channel in the
            per-channel case).
        max_val: Maximum values (one entry per channel in the
            per-channel case).

    Returns:
        A ``(scales, zero_points)`` tuple. ``scales`` is passed through
        unchanged from the parent implementation; ``zero_points`` is
        cast to ``long`` when ``_version_under_1100`` is set and to
        ``int`` otherwise.
    """
    scales, zero_points = super()._calculate_qparams(min_val, max_val)
    # NOTE(review): `_version_under_1100` is defined outside this view;
    # presumably it flags torch releases older than 1.10.0 -- confirm.
    if not _version_under_1100:
        return scales, zero_points.int()
    return scales, zero_points.long()
78+
5779
@torch.jit.export
5880
def _calculate_qmin_qmax(self) -> Tuple[int, int]:
5981
r"""Calculates actual qmin and qmax based on the quantization range,

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
torch==1.10.0
22
torchvision==0.11.1
3-
onnx==1.7.0
43
numpy==1.19.0
54
protobuf==3.20.3
65
prettytable
6+
onnx==1.13.1

0 commit comments

Comments
 (0)