Commit f858ffa

Add FMS addon for INT8xINT8 support
Signed-off-by: Andrea Fasoli <[email protected]>
1 parent a910ae6 commit f858ffa

File tree

4 files changed, +604 -0 lines changed


fms_mo/aiu_addons/int8/__init__.py

Whitespace-only changes.
Lines changed: 123 additions & 0 deletions

@@ -0,0 +1,123 @@
# Copyright The FMS Model Optimizer Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implement FMS adapter for INT8xINT8 checkpoints"""

# Standard
from typing import Mapping

# Third Party
from fms.utils import serialization
import torch


def _int8_qparams_aiu(
    input_sd: Mapping[str, torch.Tensor],
) -> Mapping[str, torch.Tensor]:
    """Rename quantization parameters in a partial state dict: clip values
    receive a `w_` (weight) or `a_` (activation) prefix, and `smoothq` keys
    become `smoothquant`. Defaults and concatenated `qdata` tensors are then
    added by `_add_defaults_and_concat`."""
    new_sd = {}
    modules_seen = set()
    for name, param in input_sd.items():
        new_name = name
        if "clip_val" in name:
            name_split = name.split(".")
            is_weight = "weight" in name_split[-2]
            module_name = ".".join(name_split[:-2])
            modules_seen.add(module_name)

            param_type = "w" if is_weight else "a"
            new_name = f"{module_name}.{param_type}_{name_split[-1]}"
        elif "smoothq" in name:
            new_name = name.replace("smoothq", "smoothquant")

        new_sd[new_name] = param

    _add_defaults_and_concat(new_sd, modules_seen)
    return new_sd
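
# Illustrative renaming performed above (the key names here are hypothetical,
# chosen only to show the pattern; actual names depend on the fms_mo
# checkpoint):
#   layers.0.attn.query.quantize_weight.clip_val   -> layers.0.attn.query.w_clip_val
#   layers.0.attn.query.quantize_feature.clip_valn -> layers.0.attn.query.a_clip_valn
#   layers.0.attn.query.smoothq_scale              -> layers.0.attn.query.smoothquant_scale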


def _add_defaults_and_concat(
    new_sd: Mapping[str, torch.Tensor],
    modules_seen: set,
) -> None:
    """
    Add default activation clip values, zero_shift, and smoothquant_scale (if
    not already present) to every linear module processed in the partial state
    dict. Weight clip values are assumed to always be present and need no
    default.

    For every module, also create a float32 `qdata` tensor as the
    concatenation of the quantization metadata tensors, as required by the
    AIU.
    """

    for module_name in modules_seen:
        # add default activation clip values (both), if not present
        if module_name + ".a_clip_val" not in new_sd:
            a_clip_val = torch.zeros(1, dtype=torch.float16)
            a_clip_valn = torch.zeros(1, dtype=torch.float16)
            new_sd[module_name + ".a_clip_val"] = a_clip_val
            new_sd[module_name + ".a_clip_valn"] = a_clip_valn
        else:
            a_clip_val = new_sd[module_name + ".a_clip_val"]
            a_clip_valn = new_sd[module_name + ".a_clip_valn"]

        # add default zero shift, if not present
        if module_name + ".zero_shift" not in new_sd:
            zero_shift = torch.zeros(1, dtype=torch.float32)
            new_sd[module_name + ".zero_shift"] = zero_shift
        else:
            zero_shift = new_sd[module_name + ".zero_shift"]

        # add default smoothquant scale, if not present
        if module_name + ".smoothquant_scale" not in new_sd:
            sq_scale = torch.ones(1, dtype=torch.float16)
            new_sd[module_name + ".smoothquant_scale"] = sq_scale
        else:
            sq_scale = new_sd[module_name + ".smoothquant_scale"]

        # add concatenated quantization metadata to state dict
        new_sd[module_name + ".qdata"] = torch.cat(
            (
                new_sd[module_name + ".w_clip_val"].to(torch.float32),
                new_sd[module_name + ".w_clip_valn"].to(torch.float32),
                a_clip_val.to(torch.float32),
                a_clip_valn.to(torch.float32),
                zero_shift.to(torch.float32),  # should already be fp32
                sq_scale.to(torch.float32),
            )
        )
    return
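
# Resulting layout of `qdata` (a sketch, assuming all metadata tensors are
# single-element, i.e. per-tensor quantization): a 6-element float32 tensor,
#   [w_clip_val, w_clip_valn, a_clip_val, a_clip_valn, zero_shift, smoothquant_scale]
# Per-channel weight clip values would simply occupy more leading elements.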


# registration of new adapter steps for each architecture
serialization.register_adapter_step("llama", "int8_qparams_aiu", _int8_qparams_aiu)
serialization.register_adapter_step(
    "gpt_bigcode", "int8_qparams_aiu", _int8_qparams_aiu
)
serialization.register_adapter_step("roberta", "int8_qparams_aiu", _int8_qparams_aiu)

# registration of multi-step adapter for each architecture
serialization.register_adapter(
    "llama",
    "fms_mo",
    [
        "hf_to_fms_names",
        "hf_to_fms_rope",
        "weight_fusion",
        "int8_qparams_aiu",
    ],
)
serialization.register_adapter(
    "gpt_bigcode", "fms_mo", ["hf_to_fms_names", "weight_fusion", "int8_qparams_aiu"]
)
serialization.register_adapter(
    "roberta", "fms_mo", ["hf_to_fms_names", "weight_fusion", "int8_qparams_aiu"]
)
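
With the adapter steps and multi-step adapters registered above, FMS can translate an fms_mo INT8xINT8 checkpoint while loading a model. A minimal usage sketch, assuming fms's `get_model` entry point; the variant string and checkpoint path are hypothetical placeholders:

# Third Party
from fms.models import get_model

# source="fms_mo" selects the multi-step adapter registered above, which
# renames quantization parameters and builds the per-module `qdata` tensors.
model = get_model(
    "llama",
    "7b",  # hypothetical variant
    model_path="/path/to/int8_checkpoint",  # hypothetical path
    source="fms_mo",
)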
