-
Notifications
You must be signed in to change notification settings - Fork 171
[NVBUG: 5535437] Refactor engine_dir to checkpoint_dir in PTQ examples. #365
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -118,7 +118,7 @@ def run_eval( | |||||||||||||||||||||||
max_gpu_memory, | ||||||||||||||||||||||||
dtype, | ||||||||||||||||||||||||
revision, | ||||||||||||||||||||||||
engine_dir, | ||||||||||||||||||||||||
checkpoint_dir, | ||||||||||||||||||||||||
nim_model, | ||||||||||||||||||||||||
args, | ||||||||||||||||||||||||
): | ||||||||||||||||||||||||
|
@@ -150,7 +150,7 @@ def run_eval( | |||||||||||||||||||||||
revision=revision, | ||||||||||||||||||||||||
top_p=top_p, | ||||||||||||||||||||||||
temperature=temperature, | ||||||||||||||||||||||||
engine_dir=engine_dir, | ||||||||||||||||||||||||
checkpoint_dir=checkpoint_dir, | ||||||||||||||||||||||||
nim_model=nim_model, | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
Comment on lines
+153
to
155
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion | 🟠 Major Pass explicit flags to remote worker; avoid relying on globals inside Ray When use_ray is true, get_model_answers runs out-of-process. Pass required options explicitly to avoid NameError on globals. Apply: checkpoint_dir=checkpoint_dir,
+ trust_remote_code=args.trust_remote_code,
+ quant_cfg=args.quant_cfg,
+ calib_batch_size=args.calib_batch_size,
+ calib_size=args.calib_size,
+ auto_quantize_bits=args.auto_quantize_bits,
nim_model=nim_model, 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
|
||||||||||||||||||||||||
for i in range(0, len(questions), chunk_size) | ||||||||||||||||||||||||
|
@@ -174,25 +174,22 @@ def get_model_answers( | |||||||||||||||||||||||
revision, | ||||||||||||||||||||||||
top_p=None, | ||||||||||||||||||||||||
temperature=None, | ||||||||||||||||||||||||
engine_dir=None, | ||||||||||||||||||||||||
checkpoint_dir=None, | ||||||||||||||||||||||||
nim_model=None, | ||||||||||||||||||||||||
): | ||||||||||||||||||||||||
# Model Optimizer modification | ||||||||||||||||||||||||
if engine_dir: | ||||||||||||||||||||||||
tokenizer = get_tokenizer(model_path, trust_remote_code=args.trust_remote_code) | ||||||||||||||||||||||||
if engine_dir: | ||||||||||||||||||||||||
# get model type | ||||||||||||||||||||||||
last_part = os.path.basename(engine_dir) | ||||||||||||||||||||||||
model_type = last_part.split("_")[0] | ||||||||||||||||||||||||
# Some models require to set pad_token and eos_token based on external config (e.g., qwen) | ||||||||||||||||||||||||
if model_type == "qwen": | ||||||||||||||||||||||||
tokenizer.pad_token = tokenizer.convert_ids_to_tokens(151643) | ||||||||||||||||||||||||
tokenizer.eos_token = tokenizer.convert_ids_to_tokens(151643) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
assert LLM is not None, "tensorrt_llm APIs could not be imported." | ||||||||||||||||||||||||
model = LLM(engine_dir, tokenizer=tokenizer) | ||||||||||||||||||||||||
else: | ||||||||||||||||||||||||
raise ValueError("engine_dir is required for TensorRT LLM inference.") | ||||||||||||||||||||||||
tokenizer = get_tokenizer(model_path, trust_remote_code=args.trust_remote_code) | ||||||||||||||||||||||||
if checkpoint_dir: | ||||||||||||||||||||||||
# get model type | ||||||||||||||||||||||||
last_part = os.path.basename(checkpoint_dir) | ||||||||||||||||||||||||
model_type = last_part.split("_")[0] | ||||||||||||||||||||||||
# Some models require to set pad_token and eos_token based on external config (e.g., qwen) | ||||||||||||||||||||||||
if model_type == "qwen": | ||||||||||||||||||||||||
tokenizer.pad_token = tokenizer.convert_ids_to_tokens(151643) | ||||||||||||||||||||||||
tokenizer.eos_token = tokenizer.convert_ids_to_tokens(151643) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
assert LLM is not None, "tensorrt_llm APIs could not be imported." | ||||||||||||||||||||||||
model = LLM(checkpoint_dir, tokenizer=tokenizer) | ||||||||||||||||||||||||
elif not nim_model: | ||||||||||||||||||||||||
model, _ = load_model( | ||||||||||||||||||||||||
model_path, | ||||||||||||||||||||||||
|
@@ -205,7 +202,6 @@ def get_model_answers( | |||||||||||||||||||||||
cpu_offloading=False, | ||||||||||||||||||||||||
debug=False, | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
tokenizer = get_tokenizer(model_path, trust_remote_code=args.trust_remote_code) | ||||||||||||||||||||||||
if args.quant_cfg: | ||||||||||||||||||||||||
quantize_model( | ||||||||||||||||||||||||
model, | ||||||||||||||||||||||||
|
@@ -259,7 +255,7 @@ def get_model_answers( | |||||||||||||||||||||||
|
||||||||||||||||||||||||
# some models may error out when generating long outputs | ||||||||||||||||||||||||
try: | ||||||||||||||||||||||||
if not engine_dir: | ||||||||||||||||||||||||
if not checkpoint_dir: | ||||||||||||||||||||||||
output_ids = model.generate( | ||||||||||||||||||||||||
torch.as_tensor(input_ids).cuda(), | ||||||||||||||||||||||||
do_sample=do_sample, | ||||||||||||||||||||||||
|
@@ -427,9 +423,9 @@ def reorg_answer_file(answer_file): | |||||||||||||||||||||||
help="The model revision to load.", | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
parser.add_argument( | ||||||||||||||||||||||||
"--engine-dir", | ||||||||||||||||||||||||
"--checkpoint-dir", | ||||||||||||||||||||||||
type=str, | ||||||||||||||||||||||||
help="The path to the TensorRT LLM engine directory.", | ||||||||||||||||||||||||
help="The path to the model checkpoint directory.", | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
parser.add_argument( | ||||||||||||||||||||||||
"--nim-model", | ||||||||||||||||||||||||
|
@@ -502,7 +498,7 @@ def reorg_answer_file(answer_file): | |||||||||||||||||||||||
max_gpu_memory=args.max_gpu_memory, | ||||||||||||||||||||||||
dtype=str_to_torch_dtype(args.dtype), | ||||||||||||||||||||||||
revision=args.revision, | ||||||||||||||||||||||||
engine_dir=args.engine_dir, | ||||||||||||||||||||||||
checkpoint_dir=args.checkpoint_dir, | ||||||||||||||||||||||||
nim_model=args.nim_model, | ||||||||||||||||||||||||
args=args, | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion | 🟠 Major
🧩 Analysis chain
Rename to checkpoint_dir in run_eval: looks good
Parameter rename is consistent.
Run to confirm no stale uses remain:
🏁 Script executed:
Length of output: 148
🏁 Script executed:
Length of output: 5278
Stale engine_dir references in Windows accuracy benchmark examples
--engine_dir
with--checkpoint-dir
engine_dir
tocheckpoint_dir
parser.add_argument("--engine_dir",…)
and all helper functions to usecheckpoint_dir
instead🤖 Prompt for AI Agents