Commit dc803bb

Add int8 sd conversion function for AIU

Signed-off-by: Andrea Fasoli <[email protected]>
1 parent 2c75fcc commit dc803bb

File tree

1 file changed: +228 −0

fms_mo/utils/aiu_utils.py

Lines changed: 228 additions & 0 deletions
@@ -0,0 +1,228 @@
# Copyright The FMS Model Optimizer Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Standard
from pathlib import Path
import logging

# Third Party
from transformers.modeling_utils import PreTrainedModel
import torch

# logging is only enabled for verbose output (performance is less critical during debug),
# and f-string style logging is preferred for code readability
# pylint: disable=logging-not-lazy


logger = logging.getLogger()


def get_quantized_linear_names(model_type: str) -> list[str]:
    """Return a list of unique identifiers for the linear layers in a given model."""

    if model_type in ["granite", "llama"]:
        return [
            "self_attn.q_proj",
            "self_attn.k_proj",
            "self_attn.v_proj",
            "self_attn.o_proj",
            "mlp.gate_proj",
            "mlp.up_proj",
            "mlp.down_proj",
        ]
    if model_type == "gpt_bigcode":
        return [
            "attn.c_attn",
            "attn.c_proj",
            "mlp.c_fc",
            "mlp.c_proj",
        ]
    if model_type in ["bert", "roberta"]:
        return [
            "attention.self.query",
            "attention.self.key",
            "attention.self.value",
            "attention.output.dense",
            "intermediate.dense",
            "output.dense",
        ]
    raise NotImplementedError(
        f"Model type {model_type} is not supported for quantized checkpoint saving"
    )
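
# Note: the identifiers above are matched as substrings of full parameter names
# in convert_sd_for_aiu below, e.g. a key like "model.layers.0.self_attn.q_proj.weight"
# is covered by the "self_attn.q_proj" entry.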


def convert_sd_for_aiu(
    model: PreTrainedModel,
    verbose: bool,
) -> dict[str, torch.Tensor]:
    """Convert the state dictionary (sd) of an FMS-MO-quantized model into a format
    compatible with the AIU.

    Expected tensors in the input state dictionary:
        weights:
            [out_feat, in_feat]
        w_cv:
            perT  [1]
            perCh [out_feat]
        w_cvn = -w_cv  <--- always symmetric!
        a_cv:
            per-token-max: n/a
            perT  [1]
        a_cvn: symmetric or asymmetric

    The smoothquant combined scale is computed as:
        s_sq = a_sq_scale ^ alpha / w_sq_scale ^ (1 - alpha)

    All parameters except quantized weights are cast to FP16, per AIU requirement.
    """

    if verbose:
        logger.info("Before conversion:")
        logger.info("* ALL MODEL PARAMETERS (name, size, dtype)")
        logger.info(
            "\n"
            + "\n".join(
                f"{k:80} {str(list(v.size())):15} {v.dtype}"
                for k, v in model.named_parameters()
            )
        )
        logger.info("* ALL BUFFERS (name, size, dtype)")
        logger.info(
            "\n"
            + "\n".join(
                f"{k:80} {str(list(v.size())):15} {v.dtype}"
                for k, v in model.named_buffers()
            )
        )
        logger.info("=" * 60)

    model_type = getattr(model.config, "model_type", None)
    if model_type:
        quantized_layers = get_quantized_linear_names(model_type)
    else:
        raise ValueError(
            "Could not determine model type to save quantized state dictionary."
        )
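
    # Keys below are quantizer bookkeeping (calibration counters, observers, and
    # the raw smoothquant buffers, which are folded into "smoothq_scale" in the
    # loop below); they are excluded from the converted state dict.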
    excluded_keys_from_new_sd = [
        "calib_counter",
        "num_module_called",
        "smoothq_act_scale",
        "smoothq_alpha",
        "obsrv_w_clipval",
        "obsrv_clipval",
        "obsrv_clipvaln",
    ]

    new_sd = {}
    for k, v in model.state_dict().items():
        if k.endswith(".weight") and any(qlayer in k for qlayer in quantized_layers):
            layername = k[:-7]  # strip the ".weight" suffix

            # smoothquant processing:
            # - if smoothquant wasn't used, smoothq_alpha doesn't exist or is zero
            # - compute combined weight/activation smoothquant scaling factor (sq_scale)
            # - rescale weights before quantization
            # - store scaling factor into state dict
            v_scaled = None
            if layername + ".smoothq_alpha" in model.state_dict():
                sq_a_scale = model.state_dict()[layername + ".smoothq_act_scale"]
                if sq_a_scale.sum() != 0:
                    sq_alpha = model.state_dict()[layername + ".smoothq_alpha"]
                    sq_w_scale = v.abs().max(dim=0, keepdim=True).values.clamp(min=1e-5)
                    sq_scale = sq_a_scale.pow(sq_alpha) / sq_w_scale.pow(1 - sq_alpha)
                    v_scaled = v * sq_scale  # weights sq-scaled before quantization
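                    # (smoothquant, conceptually: the weights absorb this scale so
                    # that the matching inverse scaling can be applied to activations
                    # at inference, shifting quantization difficulty away from
                    # activation outliers)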
                    # guarding FP16 casting
                    if sq_scale.abs().max() > torch.finfo(torch.float16).max:
                        raise ValueError(
                            "Quantization parameter (sq_scale) exceeds float16 range. "
                            "Aborted state dict saving."
                        )
                    new_sd[layername + ".smoothq_scale"] = (
                        sq_scale.squeeze().to(torch.float16).to("cpu")
                    )

            # quantize weights and store them into state dict
            if layername + ".quantize_weight.clip_val" in model.state_dict():
                w_cv = model.state_dict()[layername + ".quantize_weight.clip_val"]
                if w_cv.numel() > 1:
                    w_cv = w_cv.unsqueeze(dim=1)
                weight_pre_quant = v_scaled if v_scaled is not None else v
                weight_int = torch.clamp(
                    127 / w_cv * weight_pre_quant, -127, 127
                ).round()
                new_sd[k] = weight_int.to(torch.int8).to("cpu")  # signed int8
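
                # Note: this is symmetric signed-int8 quantization: scaling by
                # 127 / clip_val maps [-clip_val, clip_val] onto [-127, 127], and
                # values outside the clip range saturate before rounding.
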
            a_cv_name = layername + ".quantize_feature.clip_val"
            a_cvn_name = a_cv_name + "n"
            a_cv = None
            a_cvn = None
            if a_cv_name in model.state_dict():
                a_cv = model.state_dict()[a_cv_name]
            if a_cvn_name in model.state_dict():
                a_cvn = model.state_dict()[a_cvn_name]

# compute "zero_shift" correction factor only for asymmetric activations
176+
if a_cv and a_cvn and a_cv != -a_cvn:
177+
if v.dim() == 2:
178+
# weight_int: [out_feat, in_feat]
179+
# sum (squash) along in_feat dimension: dim=1
180+
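                    # (the row sums feed the standard zero-point correction: with
                    # an activation zero point z, W_int @ x_int picks up a constant
                    # z * sum_j W_int[i, j] on each output row i)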
new_sd[layername + ".zero_shift"] = (
181+
torch.sum(
182+
weight_int,
183+
dim=1,
184+
)
185+
.to(torch.float16)
186+
.to("cpu")
187+
)
188+
else:
189+
raise NotImplementedError(
190+
"Zero shift computation for tensor "
191+
"with more than 2 dims is not supported yet."
192+
)
        elif all(excluded_key not in k for excluded_key in excluded_keys_from_new_sd):
            # guarding FP16 cast
            if v.abs().max() > torch.finfo(torch.float16).max:
                raise ValueError(
                    f"Quantization parameter ({k}) exceeds float16 range. "
                    "Aborted state dict saving."
                )
            new_sd[k] = v.to("cpu").to(torch.float16)

logger.info("New state dict processed.")
203+
if verbose:
204+
logger.info(
205+
"\n"
206+
+ "\n".join(
207+
f"{k:80} {str(list(v.size())):15} "
208+
f"{str(v.dtype):18} {str(v.device):10} "
209+
f"{v.reshape(-1)[0].item():12.4f} "
210+
f"{v.min().item():12.4f} {v.max().item():12.4f}"
211+
for k, v in new_sd.items()
212+
)
213+
)
214+
215+
return new_sd


def save_sd_for_aiu(
    model: PreTrainedModel,
    output_dir: str,
    savename: str = "qmodel_state_dict.pt",
    verbose: bool = False,
) -> None:
    """Save model state dictionary after conversion for AIU compatibility."""

    converted_sd = convert_sd_for_aiu(model, verbose)
    torch.save(converted_sd, Path(output_dir) / savename)
    logger.info("Model saved.")
