@@ -1,17 +1,12 @@
+import argparse
+import re
+from typing import Tuple
+
 import torch
 import torch.functional as F
-from typing import Tuple
 import transformers
-from transformers import AutoModelForCausalLM, AutoTokenizer
 from datasets import load_dataset
-import re
-
-MODEL_ID = "facebook/opt-125m"
-# MODEL_ID = "echarlaix/tiny-random-mistral"
-
-
-NUM_PROMPTS = 512
-MAX_SEQ_LEN = 512
+from transformers import AutoModelForCausalLM, AutoTokenizer


 # HACK: override the dtype_byte_size function in transformers to support float8 types
@@ -208,15 +203,23 @@ def quantize_activations(model, calibration_tokens):


 if __name__ == "__main__":
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model-id", type=str)
+    parser.add_argument("--save-dir", type=str)
+    # parser.add_argument("--static-act", action="store_true")
+    parser.add_argument("--num-samples", type=int, default=512)
+    parser.add_argument("--max-seq-len", type=int, default=512)
+    args = parser.parse_args()
+
+    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
     sample_input_tokens = tokenizer.apply_chat_template(
         [{"role": "user", "content": "What is your name?"}],
         add_generation_prompt=True,
         return_tensors="pt",
     ).to("cuda")

     ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
-    ds = ds.shuffle(seed=42).select(range(NUM_PROMPTS))
+    ds = ds.shuffle(seed=42).select(range(args.num_samples))
     ds = ds.map(
         lambda batch: {
             "text": tokenizer.apply_chat_template(batch["messages"], tokenize=False)
@@ -228,14 +231,14 @@ def quantize_activations(model, calibration_tokens):
         return_tensors="pt",
         truncation=True,
         padding="max_length",
-        max_length=MAX_SEQ_LEN,
+        max_length=args.max_seq_len,
         add_special_tokens=False,
     ).input_ids.to("cuda")
     print("Calibration tokens:", calibration_tokens.shape)

     # Load and test the model
     model = AutoModelForCausalLM.from_pretrained(
-        MODEL_ID, torch_dtype="auto", device_map="auto"
+        args.model_id, torch_dtype="auto", device_map="auto"
     )
     output = model.generate(input_ids=sample_input_tokens, max_new_tokens=20)
     print("ORIGINAL:\n", tokenizer.decode(output[0]), "\n\n")
@@ -251,9 +254,8 @@ def quantize_activations(model, calibration_tokens):
     print("ACT QUANT:\n", tokenizer.decode(output[0]), "\n\n")

     # Save the model fully quantized
-    output_path = "fp8-static-quant"
-    print(f"Saving the model to {output_path}")
+    print(f"Saving the model to {args.save_dir}")
     static_q_dict = {"quantization_config": {"quant_method": "fp8", "scheme": "static"}}
     model.config.update(static_q_dict)
-    model.save_pretrained(output_path)
-    tokenizer.save_pretrained(output_path)
+    model.save_pretrained(args.save_dir)
+    tokenizer.save_pretrained(args.save_dir)
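(Also not part of the diff: after `model.save_pretrained(args.save_dir)`, the `quantization_config` entry injected via `model.config.update(static_q_dict)` should be serialized into the exported `config.json`. A quick sanity check one might run, where the directory below is just the old hard-coded default and stands in for whatever was passed as `--save-dir`:)

```python
import json
import os

save_dir = "fp8-static-quant"  # hypothetical; substitute the actual --save-dir value

# The fp8/static block added through model.config.update(static_q_dict) should
# appear alongside the usual model fields in the saved config.json.
with open(os.path.join(save_dir, "config.json")) as f:
    config = json.load(f)

print(config.get("quantization_config"))
# expected: {'quant_method': 'fp8', 'scheme': 'static'}
```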