@@ -9,6 +9,7 @@

 import argparse
 import copy
+import json
 import os
 import timeit
 from contextlib import nullcontext
@@ -21,9 +22,11 @@
 from torchtrt_ext import register_sdpa
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from utils import (
+    convert_linear_to_tensorrt_quantized,
     export_llm,
     generate,
     generate_with_static_cache,
+    quantize_model,
     record_stats,
     time_generate,
 )
@@ -48,6 +51,7 @@ def get_model(args):
         torch.nn.Module: The loaded and configured model ready for inference,
             moved to CUDA device with the specified precision
     """
+
     with torch.no_grad():
         model = (
             AutoModelForCausalLM.from_pretrained(
@@ -58,6 +62,8 @@ def get_model(args):
             .eval()
             .cuda()
         )
+        if args.pre_quantized:
+            model = convert_linear_to_tensorrt_quantized(model, args.model)

     if args.precision == "FP16":
         model = model.to(torch.float16)
@@ -106,7 +112,23 @@ def compile_torchtrt(model, input_ids, args):
     else:
         enabled_precisions = {torch.float32}

-    with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+    qformat = "_q_" + args.qformat if args.qformat else ""
+
+    logging_dir = f"./{args.model}_{args.precision}{qformat}"
+    # with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+    with (
+        torch_tensorrt.dynamo.Debugger(
+            "debug",
+            logging_dir=logging_dir,
+            # capture_fx_graph_after=["constant_fold"],
+            # save_engine_profile=True,
+            # profile_format="trex",
+            engine_builder_monitor=False,
+            # save_layer_info=True,
+        )
+        if args.debug
+        else nullcontext()
+    ):
         trt_model = torch_tensorrt.dynamo.compile(
             ep,
             inputs=[input_ids, position_ids],
@@ -129,12 +151,14 @@ def print_outputs(backend_name, gen_tokens, tokenizer):
     """
     Print the generated tokens from the model.
     """
+    out = tokenizer.decode(gen_tokens[0], skip_special_tokens=True)
     print(f"========= {backend_name} =========")
     print(
         f"{backend_name} model generated text: ",
-        tokenizer.decode(gen_tokens[0], skip_special_tokens=True),
+        out,
     )
     print("===================================")
+    return out


 def measure_perf(trt_model, input_signature, backend_name):
@@ -234,13 +258,24 @@ def measure_perf(trt_model, input_signature, backend_name):
     arg_parser.add_argument(
         "--benchmark", action="store_true", help="Enable benchmark (default: False)"
     )
-
+    arg_parser.add_argument(
+        "--qformat",
+        help=("Apply quantization format. Options: fp8 (default: None)"),
+        default=None,
+    )
+    arg_parser.add_argument(
+        "--pre_quantized",
+        action="store_true",
+        help="Use pre-quantized model weights (default: False)",
+    )
     args = arg_parser.parse_args()
     with torch.inference_mode():
         model = get_model(args)

         tokenizer = AutoTokenizer.from_pretrained(args.tokenizer or args.model)
-
+        # Set pad token
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
         # Prepare input for benchmarking or evaluation
         if args.benchmark:
             input_ids = torch.randint(
@@ -257,7 +292,8 @@ def measure_perf(trt_model, input_signature, backend_name):
         pyt_gen_tokens = None
         pyt_timings = None
         pyt_stats = None
-
+        if args.qformat is not None:
+            model = quantize_model(model, args, tokenizer)
         if args.enable_pytorch_run:
             pyt_gen_tokens = generate(
                 model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id
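Note: the quantize_model helper called above is imported from utils and its implementation is not part of this diff. Purely as a hypothetical sketch of what such a helper might do, assuming NVIDIA TensorRT Model Optimizer (modelopt.torch.quantization) as the quantization backend and fp8 as the only supported --qformat value; the function name, prompt, and configuration below are illustrative assumptions, not the repository's actual code:

    # Hypothetical sketch only; not the utils.quantize_model from this change.
    import modelopt.torch.quantization as mtq
    import torch

    def quantize_model_sketch(model, args, tokenizer):
        # Build a tiny calibration batch from a placeholder prompt.
        calib_ids = tokenizer("The quick brown fox", return_tensors="pt").input_ids.cuda()

        def forward_loop(m):
            # ModelOpt invokes this to run calibration forward passes.
            with torch.no_grad():
                m(calib_ids)

        if args.qformat == "fp8":
            # Insert FP8 quantizers and calibrate them in place.
            model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop=forward_loop)
        else:
            raise ValueError(f"Unsupported qformat: {args.qformat}")
        return model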
@@ -336,19 +372,41 @@ def measure_perf(trt_model, input_signature, backend_name):
                 batch_size=args.batch_size,
                 compile_time_s=None,
             )
+        match_result = "N/A"
+        torch_out = "N/A"
+        model_name = args.model.replace("/", "_")
+        qformat = args.qformat if args.qformat else "no_quant"

         if not args.benchmark:
             if args.enable_pytorch_run:
-                print_outputs("PyTorch", pyt_gen_tokens, tokenizer)
+                torch_out = print_outputs("PyTorch", pyt_gen_tokens, tokenizer)

-            print_outputs("TensorRT", trt_gen_tokens, tokenizer)
+            trt_out = print_outputs("TensorRT", trt_gen_tokens, tokenizer)

             if args.enable_pytorch_run:
                 print(
                     f"PyTorch and TensorRT outputs match: {torch.equal(pyt_gen_tokens, trt_gen_tokens)}"
                 )
-
+                match_result = str(torch.equal(pyt_gen_tokens, trt_gen_tokens))
+            out_json_file = f"{model_name}_{qformat}_match.json"
+            result = {}
+            result["match"] = match_result
+            result["torch_out"] = torch_out
+            result["trt_out"] = trt_out
+            with open(os.path.join("result", out_json_file), "w") as f:
+                json.dump(result, f, indent=4)
+            print(f"Results saved to {out_json_file}")
         if args.benchmark:
+            result = {}
+            args_dict = vars(args)
+
+            result["args"] = args_dict
+            result["pyt_stats"] = pyt_stats if args.enable_pytorch_run else None
+            result["trt_stats"] = trt_stats if args.benchmark else None
+            out_json_file = f"{model_name}_{qformat}_benchmark.json"
+            with open(os.path.join("result", out_json_file), "w") as f:
+                json.dump(result, f, indent=4)
+            print(f"Results saved to {out_json_file}")
             if args.enable_pytorch_run:
                 print("=========PyTorch PERFORMANCE============\n")
                 print(pyt_stats)
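One practical note on the JSON dumps added above: both write to os.path.join("result", out_json_file) without creating the directory, so unless the result directory already exists or is created elsewhere in the script, the open call will raise FileNotFoundError. A minimal guard, not present in this diff, could be placed before either dump:

    # Not part of the diff: make sure the output directory exists first.
    import os

    os.makedirs("result", exist_ok=True)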