Skip to content

Commit ea8540b

Browse files
authored
support ptpc-int8 (#225)
1 parent 4e6a606 commit ea8540b

File tree

1 file changed

+58
-23
lines changed

1 file changed

+58
-23
lines changed
Lines changed: 58 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,33 @@
2626

2727
from angelslim.compressor.quant.core.quant_func import weight_dequant
2828

29+
SUFFIX_TO_QUANT = [
30+
".gate_and_up_proj.weight",
31+
".gate_proj.weight",
32+
".up_proj.weight",
33+
".down_proj.weight",
34+
".q_a_proj.weight",
35+
".q_b_proj.weight",
36+
".kv_a_proj_with_mqa.weight",
37+
".kv_b_proj.weight",
38+
".qkv_proj.weight",
39+
".q_proj.weight",
40+
".k_proj.weight",
41+
".v_proj.weight",
42+
".o_proj.weight",
43+
".indexer.wq_b.weight",
44+
".indexer.wk.weight",
45+
]
46+
2947

3048
def process_worker(
31-
worker_id, safetensor_files, fp8_path, int8_path, weight_map, return_dict
49+
worker_id,
50+
safetensor_files,
51+
input_path,
52+
int8_path,
53+
weight_map,
54+
return_dict,
55+
input_type="bf16",
3256
):
3357
"""
3458
Process worker.
@@ -51,18 +75,19 @@ def process_worker(
5175
keys = set(f.keys())
5276
for weight_name in keys:
5377
weight = f.get_tensor(weight_name)
54-
scale_inv_name = f"{weight_name}_scale_inv"
55-
if scale_inv_name in weight_map:
78+
if any(weight_name.endswith(suffix) for suffix in SUFFIX_TO_QUANT):
5679
quant_count += 1
57-
# 1. fp8 dequant to bf16
58-
scale_inv = get_tensor_from_file(
59-
rank, scale_inv_name, weight_map, fp8_path
60-
)
61-
weight_bf16 = weight_dequant(weight, scale_inv)
62-
# 2. bf16 quant to int8
80+
if input_type == "fp8":
81+
scale_inv_name = f"{weight_name}_scale_inv"
82+
scale_inv = get_tensor_from_file(
83+
rank, scale_inv_name, weight_map, input_path
84+
)
85+
weight_bf16 = weight_dequant(weight, scale_inv)
86+
else:
87+
weight_bf16 = weight
6388
int8_weight, scale_inv = weight_quant(weight_bf16)
6489
new_state_dict[weight_name] = int8_weight
65-
new_scale_name = scale_inv_name.replace("_scale_inv", "_scale")
90+
new_scale_name = f"{weight_name}_scale"
6691
new_state_dict[new_scale_name] = scale_inv
6792
new_weight_map[weight_name] = file_name
6893
new_weight_map[new_scale_name] = file_name
@@ -78,7 +103,7 @@ def process_worker(
78103

79104

80105
# Helper function to get tensor from the correct file
81-
def get_tensor_from_file(rank, tensor_name, weight_map, fp8_path):
106+
def get_tensor_from_file(rank, tensor_name, weight_map, input_path):
82107
"""
83108
Retrieves a tensor from mmap safe_tensors
84109
@@ -93,7 +118,7 @@ def get_tensor_from_file(rank, tensor_name, weight_map, fp8_path):
93118
"""
94119
torch.cuda.set_device(rank)
95120
file_name = weight_map[tensor_name]
96-
file_path = os.path.join(fp8_path, file_name)
121+
file_path = os.path.join(input_path, file_name)
97122

98123
with safe_open(file_path, framework="pt", device=f"cuda:{rank}") as f:
99124
return f.get_tensor(tensor_name)
@@ -119,7 +144,7 @@ def weight_quant(tensor: torch.Tensor):
119144
return quantized.to(torch.int8), scale.to(torch.float32)
120145

121146

122-
def main(fp8_path, int8_path, num_workers):
147+
def main(input_path, int8_path, num_workers):
123148
"""
124149
Run the per-channel quantization pipeline converting FP8 or BF16 weights to INT8.
125150
@@ -130,7 +155,7 @@ def main(fp8_path, int8_path, num_workers):
130155
4. Saves quantized safetensors and updates model index.
131156
132157
Args:
133-
fp8_path (str): Path to directory containing FP8 safetensors.
158+
input_path (str): Path to directory containing input (FP8 or BF16) safetensors.
134159
int8_path (str): Output directory to save INT8 safetensors.
135160
num_workers (int): Number of processing workers
136161
"""
@@ -139,10 +164,10 @@ def main(fp8_path, int8_path, num_workers):
139164
model_index_file = os.path.join(int8_path, "model.safetensors.index.json")
140165
config_file = os.path.join(int8_path, "config.json")
141166

142-
for fname in os.listdir(fp8_path):
167+
for fname in os.listdir(input_path):
143168
if fname.endswith(".safetensors"):
144169
continue
145-
src = os.path.join(fp8_path, fname)
170+
src = os.path.join(input_path, fname)
146171
dst = os.path.join(int8_path, fname)
147172
if os.path.isdir(src):
148173
print(f"cp -r {src} {dst}")
@@ -154,7 +179,11 @@ def main(fp8_path, int8_path, num_workers):
154179
# modify config.json and save it
155180
config = json.load(open(config_file))
156181
# delete quantization_config
157-
config.pop("quantization_config", None)
182+
quant_config = config.pop("quantization_config", None)
183+
input_type = "bf16"
184+
if quant_config is not None:
185+
input_type = quant_config.get("quant_method", input_type)
186+
print("input_type", input_type)
158187
config["quantization_config"] = {
159188
"config_groups": {
160189
"group_0": {
@@ -200,9 +229,8 @@ def main(fp8_path, int8_path, num_workers):
200229
with open(model_index_file, "r") as f:
201230
model_index = json.load(f)
202231
weight_map = model_index["weight_map"]
203-
scale_count = len([key for key in weight_map.keys() if key.endswith("_scale_inv")])
204232

205-
safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
233+
safetensor_files = list(glob(os.path.join(input_path, "*.safetensors")))
206234
safetensor_files.sort()
207235
quant_count = 0
208236
new_weight_map = {}
@@ -216,7 +244,15 @@ def main(fp8_path, int8_path, num_workers):
216244
for i in range(num_workers):
217245
p = mp.Process(
218246
target=process_worker,
219-
args=(i, file_subsets[i], fp8_path, int8_path, weight_map, return_dict),
247+
args=(
248+
i,
249+
file_subsets[i],
250+
input_path,
251+
int8_path,
252+
weight_map,
253+
return_dict,
254+
input_type,
255+
),
220256
)
221257
p.start()
222258
processes.append(p)
@@ -227,7 +263,6 @@ def main(fp8_path, int8_path, num_workers):
227263
qc, wm = return_dict[i]
228264
quant_count += qc
229265
new_weight_map.update(wm)
230-
assert quant_count == scale_count
231266
print(f"{quant_count} weights are quantized.")
232267

233268
# modify model.safetensors.index.json
@@ -241,10 +276,10 @@ def main(fp8_path, int8_path, num_workers):
241276

242277
if __name__ == "__main__":
243278
parser = ArgumentParser()
244-
parser.add_argument("--input-fp8-path", type=str, required=True)
279+
parser.add_argument("--input-path", type=str, required=True)
245280
parser.add_argument("--output-int8-path", type=str, required=True)
246281
parser.add_argument("--num-workers", type=int, default=32)
247282

248283
args = parser.parse_args()
249-
main(args.input_fp8_path, args.output_int8_path, args.num_workers)
284+
main(args.input_path, args.output_int8_path, args.num_workers)
250285
print("done")

0 commit comments

Comments (0)