
Commit 9d7f152

Merge pull request #26 from andrea-fasoli/clip_symmetry
Fix symmetric behavior (issue #22)
2 parents 6b702d1 + 18f23fb commit 9d7f152

3 files changed: +38 −7 lines changed


fms_mo/calib.py

Lines changed: 20 additions & 2 deletions

@@ -72,17 +72,35 @@ def __call__(self, module, inputs: torch.Tensor):
         with torch.no_grad():
             x = inputs[0].detach()  # TODO: still need detach() under no_grad context?

+            symmetric = False
+            if module.quantize_feature:
+                # default to asymmetric clip_val computation
+                # TODO: this misses symmetry of PACTPlusSym, PACT2Sym, and QFixSymmetric
+                symmetric = not getattr(module.quantize_feature, "minmax", True)
+
             nelem = x.nelement()
             if self.a_init_method == "percentile":
                 lower_k = int(self.per[0] * nelem)
-                lower_per_cur = (
+                lower_per_cur_candidate = (
                     x.reshape(1, -1).kthvalue(lower_k).values.data[0]
                     if lower_k > 0
                     else x.min()
                 )  # guard rail: tensors with very few elements could cause kthvalue(0) error
-                upper_per_cur = (
+                upper_per_cur_candidate = (
                     x.reshape(1, -1).kthvalue(int(self.per[1] * nelem)).values.data[0]
                 )
+                if symmetric:
+                    upper_per_cur = max(
+                        upper_per_cur_candidate,
+                        lower_per_cur_candidate.abs(),
+                    )
+                    lower_per_cur = -upper_per_cur
+                else:
+                    upper_per_cur = upper_per_cur_candidate
+                    lower_per_cur = lower_per_cur_candidate
+            elif symmetric:
+                upper_per_cur = x.abs().max()
+                lower_per_cur = -upper_per_cur
             else:
                 lower_per_cur = x.min()
                 upper_per_cur = x.max()
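
For context, the percentile path above can be read as a standalone sketch. This is illustrative only, assuming a fractional per pair such as (0.001, 0.999) as in calib.py; the helper name percentile_clip is hypothetical, not an fms_mo API.

import torch

def percentile_clip(x: torch.Tensor, per=(0.001, 0.999), symmetric: bool = False):
    # hypothetical helper mirroring the clip_val logic patched above
    nelem = x.nelement()
    lower_k = int(per[0] * nelem)
    # guard rail: kthvalue(0) would error on very small tensors
    lower = x.reshape(1, -1).kthvalue(lower_k).values[0] if lower_k > 0 else x.min()
    upper = x.reshape(1, -1).kthvalue(int(per[1] * nelem)).values[0]
    if symmetric:
        # a symmetric quantizer has a single clip magnitude: take the larger
        # candidate and mirror it around zero
        upper = torch.maximum(upper, lower.abs())
        lower = -upper
    return lower, upper

x = torch.randn(10_000) - 2.0              # skewed toward negative values
print(percentile_clip(x))                  # asymmetric, e.g. roughly (-5.1, 1.1)
print(percentile_clip(x, symmetric=True))  # mirrored, e.g. roughly (-5.1, 5.1)

Before this fix, a symmetric quantizer initialized from these clip values would silently receive an asymmetric range whenever the data distribution was skewed.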

fms_mo/modules/linear.py

Lines changed: 6 additions & 1 deletion

@@ -162,12 +162,17 @@ def __init__(
                 use_subnormal=self.fp8_use_subnormal,
             )
         if self.calib_counter > 0:
+            qa_mode_calib = (
+                self.qa_mode_calib + "sym"
+                if self.qa_mode.endswith("sym")
+                else self.qa_mode_calib
+            )
             self.quantize_calib_feature = Qdynamic(
                 self.num_bits_feature,
                 qcfg,
                 non_neg=self.non_neg,
                 align_zero=self.align_zero,
-                qmode=self.qa_mode_calib,
+                qmode=qa_mode_calib,
                 quantizer2sync=self.quantize_feature,
             )
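
The effect of the change is easiest to see on the mode strings themselves; a minimal illustration, with example mode names ("pactsym", "percentile") that are assumptions rather than an exhaustive list from fms_mo:

# if the runtime activation quantizer is symmetric (qa_mode ends in "sym"),
# the dynamic calibration quantizer now gets a matching "...sym" mode
qa_mode, qa_mode_calib = "pactsym", "percentile"  # example values only
qa_mode_calib = qa_mode_calib + "sym" if qa_mode.endswith("sym") else qa_mode_calib
print(qa_mode_calib)  # -> "percentilesym"

Previously Qdynamic always received self.qa_mode_calib unchanged, so a symmetric runtime quantizer could be calibrated with asymmetric clip values.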

fms_mo/quant/quantizers.py

Lines changed: 12 additions & 4 deletions

@@ -3512,7 +3512,7 @@ def __init__(
         """
         super().__init__()
         self.num_bits = num_bits
-        self.symmetric = symmetric or qmode.endswith("_sym")
+        self.symmetric = symmetric or qmode.endswith("sym")
         self.nlevels = (
             2**self.num_bits - 2 if self.symmetric else 2**self.num_bits - 1
         )

@@ -3553,24 +3553,32 @@ def forward(self, input_tensor):
         with torch.no_grad():
             if self.qmode.startswith("percentile"):
                 nelem = input_tensor.nelement()
-                cv_new = (
+                cv_new_candidate = (
                     input_tensor.reshape(1, -1)
                     .float()
                     .kthvalue(
                         round(self.per[1] * 0.01 * nelem)
                     )  # built-in 'round' returns int
                     .values.data[0]
                 ).to(input_tensor.dtype)
+
                 # conventionally percentile is input as 99.9 (% is implied),
                 # so we need *0.01 here
                 lower_k = round(self.per[0] * 0.01 * nelem)
-                cvn_new = (
+                cvn_new_candidate = (
                     input_tensor.reshape(1, -1).float().kthvalue(lower_k).values.data[0]
                     if lower_k > 0
                     else input_tensor.min()
                 ).to(
                     input_tensor.dtype
                 )  # for very small tensor, lower_k could be 0, kthvalue(0) will cause error
+
+                if self.symmetric:
+                    cv_new = max(cv_new_candidate, cvn_new_candidate.abs())
+                    cvn_new = -cv_new
+                else:
+                    cv_new = cv_new_candidate
+                    cvn_new = cvn_new_candidate
             elif (
                 self.qmode == "sawb" and self.num_bits == 4
             ):  # only works for PACT+sym for weights

@@ -3579,7 +3587,7 @@ def forward(self, input_tensor):

             else:  # i.e., minmax
                 cv_new = input_tensor.max()
-                cvn_new = input_tensor.min()
+                cvn_new = -cv_new if self.symmetric else input_tensor.min()

             if self.Niter == 0 and self.training:
                 # to avoid unintended bwd ops added to the graph, cause memory leak sometimes
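
Both quantizers.py fixes condense into one sketch. The function below is a hypothetical stand-in, not fms_mo code, and the mode names are examples:

import torch

def minmax_clip(x: torch.Tensor, qmode: str, symmetric: bool = False):
    # fix 1: endswith("sym") also matches mode names without the underscore
    symmetric = symmetric or qmode.endswith("sym")
    cv_new = x.max()
    # fix 2: in minmax mode, a symmetric quantizer mirrors the upper clip
    # value instead of taking an independent min
    cvn_new = -cv_new if symmetric else x.min()
    return cvn_new, cv_new

x = torch.tensor([-0.5, 0.2, 1.5])
print(minmax_clip(x, "minmax"))  # (tensor(-0.5000), tensor(1.5000))
print(minmax_clip(x, "maxsym"))  # (tensor(-1.5000), tensor(1.5000))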
