Skip to content

Commit 7a5e504

Browse files
committed
Bug fix for conversion
Signed-off-by: Andrea Fasoli <[email protected]>
1 parent dded2c1 commit 7a5e504

File tree

1 file changed

+21
-19
lines changed

1 file changed

+21
-19
lines changed

fms_mo/utils/aiu_utils.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def process_smoothquant(
9898

9999
weight_scaled = None
100100
sq_a_scale = None
101-
w = new_sd[layer_name + ".weight"]
101+
w = model.state_dict()[layer_name + ".weight"]
102102
if layer_name + ".smoothq_alpha" in model.state_dict():
103103
sq_a_scale = model.state_dict()[layer_name + ".smoothq_act_scale"]
104104
if sum(sq_a_scale) != 0:
@@ -133,8 +133,6 @@ def recompute_weight_with_sawb(
133133
is_w_recomputed = False
134134
weight_int_sawb: torch.Tensor | None = None
135135
weight_int_std: torch.Tensor | float | None = None
136-
k = layer_name + ".weight"
137-
w = new_sd[k]
138136
if weight_per_channel:
139137
# recompute if any channel shows narrow int weights
140138
weight_int_std = weight_int_as_fp.std(dim=-1)
@@ -159,12 +157,12 @@ def recompute_weight_with_sawb(
159157
num_bits=8,
160158
dequantize=False,
161159
align_zero=True,
162-
perCh=w.size(0) if weight_per_channel else False,
160+
perCh=weight_int_as_fp.size(0) if weight_per_channel else False,
163161
)
164162
quantizer.training = True # set SAWB to recompute clips
165163
# some SAWB quantizers only process FP32 inputs, so weights are
166164
# temporarily upscaled
167-
weight_int_sawb = quantizer(w.to(torch.float32))
165+
weight_int_sawb = quantizer(weight_int_as_fp.to(torch.float32))
168166

169167
# 2. Recompute clip values using new SAWB quantizer
170168
w_cv_key = layer_name + ".quantize_weight.clip_val"
@@ -184,24 +182,27 @@ def recompute_weight_with_sawb(
184182
weight_int_sawb_as_fp = deepcopy(weight_int_sawb).to(torch.float32)
185183
if weight_per_channel:
186184
weight_int_sawb_std_min = weight_int_sawb_as_fp.std(dim=-1)[0].min()
187-
logger.debug(
188-
" Reprocessed weights "
189-
f"(std_min={weight_int_std_min:.1f} "
190-
f"-> {weight_int_sawb_std_min:.1f}) "
191-
f"and clips of {k}"
185+
if verbose:
186+
logger.info(
187+
" Reprocessed weights "
188+
f"(std_min={weight_int_std_min:.1f} "
189+
f"-> {weight_int_sawb_std_min:.1f}) "
190+
f"and clips of {layer_name + '.weight'}"
192191
)
193192
else:
194193
weight_int_sawb_as_fp_std = weight_int_sawb_as_fp.std()
195-
logger.debug(
196-
" Reprocessed weights "
197-
f"(std={weight_int_std:.1f} "
198-
f"-> {weight_int_sawb_as_fp_std:.1f}) "
199-
f"and clips of {k}"
200-
)
194+
if verbose:
195+
logger.info(
196+
" Reprocessed weights "
197+
f"(std={weight_int_std:.1f} "
198+
f"-> {weight_int_sawb_as_fp_std:.1f}) "
199+
f"and clips of {layer_name + '.weight'}"
200+
)
201201
else:
202202
log_min_std = "min_" if weight_per_channel else ""
203203
log_w_std = weight_int_std_min if weight_per_channel else weight_int_std
204-
logger.debug(f" Weights preserved ({log_min_std}std={log_w_std:.1f})")
204+
if verbose:
205+
logger.info(f" Weights preserved ({log_min_std}std={log_w_std:.1f})")
205206

206207
return weight_int_sawb, is_w_recomputed
207208

@@ -393,7 +394,7 @@ def convert_sd_for_aiu(
393394

394395
if recompute_narrow_weights:
395396
logger.info(
396-
f"Recomputed {num_w_recomputed} weights with SAWB, "
397+
f"Recomputed {num_w_recomputed} weight matrices with SAWB, "
397398
f"{num_w_preserved} preserved."
398399
)
399400

@@ -415,7 +416,7 @@ def save_sd_for_aiu(
415416
qcfg.get("recompute_narrow_weights", False) if qcfg is not None else False
416417
),
417418
weight_per_channel=(
418-
"perch" in qcfg.get("qw_mode", False) if qcfg is not None else False
419+
"perch" in qcfg.get("qw_mode", False).lower() if qcfg is not None else False
419420
),
420421
verbose=verbose,
421422
)
@@ -452,5 +453,6 @@ def save_for_aiu(
452453
"scale_layers",
453454
"qskip_layer_name",
454455
"qskip_large_mag_layers",
456+
"recompute_narrow_weights",
455457
]
456458
qconfig_save(qcfg, recipe=recipe, minimal=True, fname=Path(output_dir) / cfg_name)

0 commit comments

Comments (0)