Commit 049df45

fix 2 bugs in aiu_save funcs related to vector clipvals and zero_shift being fp32
Signed-off-by: cliu-us <[email protected]>
1 parent 7417662 commit 049df45

3 files changed (+15, -9 lines)

fms_mo/calib.py

Lines changed: 13 additions & 7 deletions
@@ -574,13 +574,19 @@ def qmodel_calib(
                 f"Qmodel calibration (clip_val analysis) in progress: {i}/{Nbatch}"
             )

-        if "perCh" not in qcfg["qw_mode"]:
-            cv_sum_dict = {"layer": [], "value": []}
-            for k, v in tempmodel.state_dict().items():
-                if "clip" in k:
-                    cv_sum_dict["layer"].append(k)
-                    cv_sum_dict["value"].append(v.item())
-            logger.info(f"Observed clipvals: \n{ pd.DataFrame(cv_sum_dict) }")
+        cv_sum_dict = {"layer": [], "value": []}
+        for k, v in tempmodel.state_dict().items():
+            if "clip" not in k:
+                continue
+
+            if v.numel() > 1:
+                k = k + "*"
+                v = v.mean()
+            cv_sum_dict["layer"].append(k)
+            cv_sum_dict["value"].append(v.item())
+        logger.info(
+            f"Observed clipvals: ('*' if it's a vector) \n{ pd.DataFrame(cv_sum_dict) }"
+        )

     # Step 3: extract new clip_vals, params and buffers, then remove handles if needed
     temp_new_clipvals = {
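Why this change: with per-channel ("perCh") weight quantizers the clip values are vectors, and calling .item() on a multi-element tensor raises a RuntimeError, so the old code skipped the summary entirely in that case. The new code summarizes every clip value, reducing vectors to their mean and flagging them with a trailing '*'. A minimal sketch of the new logic against a toy state dict (the key names and tensor values below are made up for illustration; tempmodel and logger from the original are replaced with stand-ins):

    import pandas as pd
    import torch

    # Toy stand-in for tempmodel.state_dict(): one scalar clip value, one
    # per-channel (vector) clip value, plus a non-clip entry that is skipped.
    state_dict = {
        "layer1.quantize_weight.clip_val": torch.tensor(2.5),
        "layer2.quantize_weight.clip_val": torch.tensor([1.0, 2.0, 3.0]),
        "layer1.weight": torch.randn(4, 4),
    }

    cv_sum_dict = {"layer": [], "value": []}
    for k, v in state_dict.items():
        if "clip" not in k:
            continue

        if v.numel() > 1:
            k = k + "*"   # flag vector clip values with '*'
            v = v.mean()  # reduce to the mean so .item() is well defined
        cv_sum_dict["layer"].append(k)
        cv_sum_dict["value"].append(v.item())

    print(pd.DataFrame(cv_sum_dict))
    # Roughly:
    #                               layer  value
    # 0   layer1.quantize_weight.clip_val    2.5
    # 1  layer2.quantize_weight.clip_val*    2.0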

fms_mo/utils/aiu_utils.py

Lines changed: 1 addition & 1 deletion
@@ -288,7 +288,7 @@ def process_zero_shift(
         a_cvn = model.state_dict()[a_cvn_name]

         # compute "zero_shift" correction factor only for asymmetric activations
-        if a_cv and a_cvn and a_cv != -a_cvn:
+        if a_cv is not None and a_cvn is not None and torch.equal(a_cv, -a_cvn):
             if weight_int is None:
                 logger.info(
                     f"As weights appear to be not quantized, zero shift for {k} "

tests/models/test_model_utils.py

Lines changed: 1 addition & 1 deletion
@@ -232,7 +232,7 @@ def check_linear_dtypes(state_dict: dict, linear_names: list):
         if any(n in k for n in linear_names):
             if k.endswith(".weight"):
                 assert v.dtype == torch.int8
-            elif k.endswith(".zero_point"):
+            elif k.endswith(".zero_point") or k.endswith(".zero_shift"):
                 assert v.dtype == torch.float32
             else:
                 assert v.dtype == torch.float16
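The test helper now also requires .zero_shift entries, like .zero_point, to be saved as fp32, matching the second bug named in the commit message. A condensed, self-contained version of the updated helper with a hypothetical state dict that satisfies it (the key names here are illustrative, not from the repo's fixtures):

    import torch

    def check_linear_dtypes(state_dict: dict, linear_names: list):
        # Weights must be int8; zero_point/zero_shift must stay fp32;
        # everything else belonging to the linear layers must be fp16.
        for k, v in state_dict.items():
            if any(n in k for n in linear_names):
                if k.endswith(".weight"):
                    assert v.dtype == torch.int8
                elif k.endswith(".zero_point") or k.endswith(".zero_shift"):
                    assert v.dtype == torch.float32
                else:
                    assert v.dtype == torch.float16

    # Hypothetical saved state dict for one quantized linear layer.
    state_dict = {
        "model.dense.weight": torch.zeros(4, 4, dtype=torch.int8),
        "model.dense.zero_point": torch.zeros(4, dtype=torch.float32),
        "model.dense.zero_shift": torch.zeros(4, dtype=torch.float32),
        "model.dense.bias": torch.zeros(4, dtype=torch.float16),
    }
    check_linear_dtypes(state_dict, linear_names=["dense"])  # passes silently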
