Commit 2ea7592

Update
Signed-off-by: Chenjie Luo <[email protected]>
1 parent: faa456a

File tree

1 file changed: +3 -12 lines changed

examples/llm_ptq/hf_ptq.py

Lines changed: 3 additions & 12 deletions
@@ -584,25 +584,16 @@ def output_decode(generated_ids, input_shape):

     start_time = time.time()
     if model_type in ["t5", "bart", "whisper"] or args.sparsity_fmt != "dense":
-        # Still export TensorRT-LLM checkpoints for the models not supported by the
-        # TensorRT-LLM torch runtime.
+        warnings.warn(
+            "Still exporting TensorRT-LLM checkpoints for models not supported by the TensorRT-LLM torch runtime."
+        )

         # Move meta tensor back to device before exporting.
         remove_hook_from_module(model, recurse=True)

-        dtype = None
-        if "w4a8_awq" in args.qformat:
-            # TensorRT-LLM w4a8 only support fp16 as the dtype.
-            dtype = torch.float16
-
-        # For Gemma2-27B, TRT-LLM only works with bfloat16 as the dtype.
-        if model_type == "gemma2":
-            dtype = torch.bfloat16
-
         export_tensorrt_llm_checkpoint(
             model,
             model_type,
-            dtype=dtype,
             export_dir=export_path,
             inference_tensor_parallel=args.inference_tensor_parallel,
             inference_pipeline_parallel=args.inference_pipeline_parallel,
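
For reference, a sketch of the affected block as it reads after this commit, reconstructed from the hunk above. The indentation and the trailing comment are assumptions; any arguments to export_tensorrt_llm_checkpoint beyond the hunk are elided.

    start_time = time.time()
    if model_type in ["t5", "bart", "whisper"] or args.sparsity_fmt != "dense":
        warnings.warn(
            "Still exporting TensorRT-LLM checkpoints for models not supported by the TensorRT-LLM torch runtime."
        )

        # Move meta tensor back to device before exporting.
        remove_hook_from_module(model, recurse=True)

        export_tensorrt_llm_checkpoint(
            model,
            model_type,
            export_dir=export_path,
            inference_tensor_parallel=args.inference_tensor_parallel,
            inference_pipeline_parallel=args.inference_pipeline_parallel,
            # ... remaining arguments, if any, fall outside the hunk shown above
        )

Net effect: the explanatory comment becomes a runtime warning, and the explicit dtype overrides (float16 for w4a8_awq, bfloat16 for gemma2) are removed, leaving dtype selection to export_tensorrt_llm_checkpoint's defaults.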
