
Commit 208df56

fp8 pre-quantized model support
1 parent a93266a


2 files changed: +289 −34 lines


tools/llm/run_llm.py

Lines changed: 66 additions & 8 deletions
@@ -9,6 +9,7 @@

 import argparse
 import copy
+import json
 import os
 import timeit
 from contextlib import nullcontext
@@ -21,9 +22,11 @@
 from torchtrt_ext import register_sdpa
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from utils import (
+    convert_linear_to_tensorrt_quantized,
     export_llm,
     generate,
     generate_with_static_cache,
+    quantize_model,
     record_stats,
     time_generate,
 )
@@ -48,6 +51,7 @@ def get_model(args):
         torch.nn.Module: The loaded and configured model ready for inference,
         moved to CUDA device with the specified precision
     """
+
     with torch.no_grad():
         model = (
             AutoModelForCausalLM.from_pretrained(
@@ -58,6 +62,8 @@ def get_model(args):
             .eval()
             .cuda()
         )
+        if args.pre_quantized:
+            model = convert_linear_to_tensorrt_quantized(model, args.model)

         if args.precision == "FP16":
             model = model.to(torch.float16)
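
The convert_linear_to_tensorrt_quantized helper comes from tools/llm/utils.py and is not part of this file's diff. As a rough sketch of the technique only (FP8Linear and the scales mapping are assumptions, not the repo's implementation): a pre-quantized fp8 checkpoint carries float8 weights plus dequantization scales, and each nn.Linear is swapped for a module that applies them.

# Illustrative sketch only; the real convert_linear_to_tensorrt_quantized
# lives in tools/llm/utils.py. FP8Linear and the scales mapping are assumed.
import torch
import torch.nn as nn


class FP8Linear(nn.Module):
    """Holds an fp8 weight and its dequantization scale from a checkpoint."""

    def __init__(self, linear: nn.Linear, weight_scale: torch.Tensor):
        super().__init__()
        # Keep the weight in float8; remember the per-tensor scale.
        self.weight = nn.Parameter(
            linear.weight.to(torch.float8_e4m3fn), requires_grad=False
        )
        self.scale = weight_scale
        self.bias = linear.bias

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Dequantize before the matmul; a real implementation would emit ops
        # that Torch-TensorRT can lower to fused fp8 GEMMs instead.
        w = self.weight.to(x.dtype) * self.scale
        return nn.functional.linear(x, w, self.bias)


def convert_linears(model: nn.Module, scales: dict[str, torch.Tensor]) -> nn.Module:
    # Replace every nn.Linear that has a recorded scale, keyed by qualified name.
    for name, module in list(model.named_modules()):
        if isinstance(module, nn.Linear) and name in scales:
            parent, _, child = name.rpartition(".")
            setattr(model.get_submodule(parent), child, FP8Linear(module, scales[name]))
    return model
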
@@ -106,7 +112,23 @@ def compile_torchtrt(model, input_ids, args):
     else:
         enabled_precisions = {torch.float32}

-    with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+    qformat = "_q_" + args.qformat if args.qformat else ""
+
+    logging_dir = f"./{args.model}_{args.precision}{qformat}"
+    # with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+    with (
+        torch_tensorrt.dynamo.Debugger(
+            "debug",
+            logging_dir=logging_dir,
+            # capture_fx_graph_after=["constant_fold"],
+            # save_engine_profile=True,
+            # profile_format="trex",
+            engine_builder_monitor=False,
+            # save_layer_info=True,
+        )
+        if args.debug
+        else nullcontext()
+    ):
         trt_model = torch_tensorrt.dynamo.compile(
             ep,
             inputs=[input_ids, position_ids],
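
For reference, a runnable check of how the debug logging directory is derived; the values below are hypothetical stand-ins for args.model, args.precision, and args.qformat.

# Hypothetical flag values, mirroring the f-string logic in the hunk above.
model, precision, qformat_arg = "Qwen/Qwen2.5-0.5B", "FP16", "fp8"
qformat = "_q_" + qformat_arg if qformat_arg else ""
logging_dir = f"./{model}_{precision}{qformat}"
print(logging_dir)  # ./Qwen/Qwen2.5-0.5B_FP16_q_fp8 (the slash nests a directory)
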
@@ -129,12 +151,14 @@ def print_outputs(backend_name, gen_tokens, tokenizer):
     """
     Print the generated tokens from the model.
     """
+    out = tokenizer.decode(gen_tokens[0], skip_special_tokens=True)
     print(f"========= {backend_name} =========")
     print(
         f"{backend_name} model generated text: ",
-        tokenizer.decode(gen_tokens[0], skip_special_tokens=True),
+        out,
     )
     print("===================================")
+    return out


 def measure_perf(trt_model, input_signature, backend_name):
@@ -234,13 +258,24 @@ def measure_perf(trt_model, input_signature, backend_name):
     arg_parser.add_argument(
         "--benchmark", action="store_true", help="Enable benchmark (default: False)"
     )
-
+    arg_parser.add_argument(
+        "--qformat",
+        help=("Apply quantization format. Options: fp8 (default: None)"),
+        default=None,
+    )
+    arg_parser.add_argument(
+        "--pre_quantized",
+        action="store_true",
+        help="Use pre-quantized model weights (default: False)",
+    )
     args = arg_parser.parse_args()
     with torch.inference_mode():
         model = get_model(args)

         tokenizer = AutoTokenizer.from_pretrained(args.tokenizer or args.model)
-
+        # Set pad token
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
         # Prepare input for benchmarking or evaluation
         if args.benchmark:
             input_ids = torch.randint(
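
The two new flags select independent paths: --pre_quantized converts weights right after from_pretrained() inside get_model, while --qformat fp8 applies quantize_model() later, just before any PyTorch run. A minimal parser sketch confirming how they parse (the invocations in the comments use placeholder model ids):

# Example invocations (model ids are placeholders):
#   python tools/llm/run_llm.py --model Qwen/Qwen2.5-0.5B --qformat fp8
#   python tools/llm/run_llm.py --model my-org/model-fp8 --pre_quantized
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--qformat", default=None)
parser.add_argument("--pre_quantized", action="store_true")

args = parser.parse_args(["--qformat", "fp8"])
assert args.qformat == "fp8" and not args.pre_quantized

args = parser.parse_args(["--pre_quantized"])
assert args.pre_quantized and args.qformat is None
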
@@ -257,7 +292,8 @@ def measure_perf(trt_model, input_signature, backend_name):
         pyt_gen_tokens = None
         pyt_timings = None
         pyt_stats = None
-
+        if args.qformat != None:
+            model = quantize_model(model, args, tokenizer)
         if args.enable_pytorch_run:
             pyt_gen_tokens = generate(
                 model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id
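
quantize_model is defined in tools/llm/utils.py and is not shown in this diff. Torch-TensorRT's LLM examples typically implement this step with NVIDIA ModelOpt post-training quantization; the sketch below assumes that route and is illustrative, not the repo's actual helper.

# Assumed shape of a quantize_model-style helper built on NVIDIA ModelOpt.
import modelopt.torch.quantization as mtq


def quantize_model_sketch(model, tokenizer):
    # A single short prompt is enough for a smoke test; real calibration
    # should iterate over a representative dataset.
    calib_ids = tokenizer(
        "The capital of France is", return_tensors="pt"
    ).input_ids.cuda()

    def calibrate_loop(m):
        m(calib_ids)  # forward passes let the quantizers record ranges

    # FP8 post-training quantization config shipped with ModelOpt.
    return mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop=calibrate_loop)
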
@@ -336,19 +372,41 @@ def measure_perf(trt_model, input_signature, backend_name):
             batch_size=args.batch_size,
             compile_time_s=None,
         )
+        match_result = "N/A"
+        torch_out = "N/A"
+        model_name = args.model.replace("/", "_")
+        qformat = args.qformat if args.qformat else "no_quant"

         if not args.benchmark:
             if args.enable_pytorch_run:
-                print_outputs("PyTorch", pyt_gen_tokens, tokenizer)
+                torch_out = print_outputs("PyTorch", pyt_gen_tokens, tokenizer)

-            print_outputs("TensorRT", trt_gen_tokens, tokenizer)
+            trt_out = print_outputs("TensorRT", trt_gen_tokens, tokenizer)

             if args.enable_pytorch_run:
                 print(
                     f"PyTorch and TensorRT outputs match: {torch.equal(pyt_gen_tokens, trt_gen_tokens)}"
                 )
-
+                match_result = str(torch.equal(pyt_gen_tokens, trt_gen_tokens))
+            out_json_file = f"{model_name}_{qformat}_match.json"
+            result = {}
+            result["match"] = match_result
+            result["torch_out"] = torch_out
+            result["trt_out"] = trt_out
+            with open(os.path.join("result", out_json_file), "w") as f:
+                json.dump(result, f, indent=4)
+            print(f"Results saved to {out_json_file}")
         if args.benchmark:
+            result = {}
+            args_dict = vars(args)
+
+            result["args"] = args_dict
+            result["pyt_stats"] = pyt_stats if args.enable_pytorch_run else None
+            result["trt_stats"] = trt_stats if args.benchmark else None
+            out_json_file = f"{model_name}_{qformat}_benchmark.json"
+            with open(os.path.join("result", out_json_file), "w") as f:
+                json.dump(result, f, indent=4)
+            print(f"Results saved to {out_json_file}")
         if args.enable_pytorch_run:
             print("=========PyTorch PERFORMANCE============ \n")
             print(pyt_stats)
