Commit b0c1f78

Update with CLI
1 parent 81c33df commit b0c1f78

File tree

1 file changed: +20 -18 lines changed

quantize.py

Lines changed: 20 additions & 18 deletions
@@ -1,17 +1,12 @@
+import argparse
+import re
+from typing import Tuple
+
 import torch
 import torch.functional as F
-from typing import Tuple
 import transformers
-from transformers import AutoModelForCausalLM, AutoTokenizer
 from datasets import load_dataset
-import re
-
-MODEL_ID = "facebook/opt-125m"
-# MODEL_ID = "echarlaix/tiny-random-mistral"
-
-
-NUM_PROMPTS = 512
-MAX_SEQ_LEN = 512
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 # HACK: override the dtype_byte_size function in transformers to support float8 types
@@ -208,15 +203,23 @@ def quantize_activations(model, calibration_tokens):
 
 
 if __name__ == "__main__":
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model-id", type=str)
+    parser.add_argument("--save-dir", type=str)
+    # parser.add_argument("--static-act", action="store_true")
+    parser.add_argument("--num-samples", type=int, default=512)
+    parser.add_argument("--max-seq-len", type=int, default=512)
+    args = parser.parse_args()
+
+    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
     sample_input_tokens = tokenizer.apply_chat_template(
         [{"role": "user", "content": "What is your name?"}],
         add_generation_prompt=True,
         return_tensors="pt",
     ).to("cuda")
 
     ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
-    ds = ds.shuffle(seed=42).select(range(NUM_PROMPTS))
+    ds = ds.shuffle(seed=42).select(range(args.num_samples))
     ds = ds.map(
         lambda batch: {
             "text": tokenizer.apply_chat_template(batch["messages"], tokenize=False)
@@ -228,14 +231,14 @@ def quantize_activations(model, calibration_tokens):
         return_tensors="pt",
         truncation=True,
         padding="max_length",
-        max_length=MAX_SEQ_LEN,
+        max_length=args.max_seq_len,
         add_special_tokens=False,
     ).input_ids.to("cuda")
     print("Calibration tokens:", calibration_tokens.shape)
 
     # Load and test the model
     model = AutoModelForCausalLM.from_pretrained(
-        MODEL_ID, torch_dtype="auto", device_map="auto"
+        args.model_id, torch_dtype="auto", device_map="auto"
     )
    output = model.generate(input_ids=sample_input_tokens, max_new_tokens=20)
     print("ORIGINAL:\n", tokenizer.decode(output[0]), "\n\n")
@@ -251,9 +254,8 @@ def quantize_activations(model, calibration_tokens):
     print("ACT QUANT:\n", tokenizer.decode(output[0]), "\n\n")
 
     # Save the model fully quantized
-    output_path = "fp8-static-quant"
-    print(f"Saving the model to {output_path}")
+    print(f"Saving the model to {args.save_dir}")
     static_q_dict = {"quantization_config": {"quant_method": "fp8", "scheme": "static"}}
     model.config.update(static_q_dict)
-    model.save_pretrained(output_path)
-    tokenizer.save_pretrained(output_path)
+    model.save_pretrained(args.save_dir)
+    tokenizer.save_pretrained(args.save_dir)
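
With the hard-coded constants replaced by flags, an invocation would look roughly like the following (a sketch only: the model ID and save directory reuse the old defaults from the removed constants, not values mandated by the commit):

    python quantize.py \
        --model-id facebook/opt-125m \
        --save-dir fp8-static-quant \
        --num-samples 512 \
        --max-seq-len 512

Note that --model-id and --save-dir are declared with neither a default nor required=True, so omitting either passes None through to from_pretrained/save_pretrained and the script only fails at that point; marking them required would surface the error at parse time instead.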
