Deprecate TRTLLM-build in examples #297
@@ -89,28 +89,20 @@ def auto_quantize(
qformat_list = qformat.split(",")
assert qformat_list, "No quantization formats provided"
# Check if all provided quantization formats are supported
if args.export_fmt == "hf":
assert all(
qformat
in [
"fp8",
"int4_awq",
"nvfp4",
"nvfp4_awq",
"w4a8_awq",
"fp8_pb_wo",
"w4a8_mxfp4_fp8",
"nvfp4_mlp_only",
]
for qformat in qformat_list
), (
"One or more quantization formats provided are not supported for unified checkpoint export"
)
else:
assert all(
qformat in ["fp8", "int8_sq", "int4_awq", "w4a8_awq", "nvfp4", "nvfp4_awq"]
for qformat in qformat_list
), "One or more quantization formats provided are not supported for tensorrt llm export"
assert all(
qformat
in [
Review comment: Do you think we can pull this list of supported qformats out as a variable and reuse it in other places (in the auto-quantize section)?
Author reply: ACK. This PR is pretty large; I hope we can move that improvement to a follow-up PR.
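A minimal sketch of what that follow-up could look like (the constant SUPPORTED_HF_QFORMATS and the helper check_qformats are hypothetical names, not part of this PR):

# A minimal sketch of the suggested follow-up; the constant name
# SUPPORTED_HF_QFORMATS and the helper check_qformats are hypothetical,
# not part of this PR.
SUPPORTED_HF_QFORMATS = [
    "fp8",
    "int4_awq",
    "nvfp4",
    "nvfp4_awq",
    "w4a8_awq",
    "fp8_pb_wo",
    "w4a8_mxfp4_fp8",
    "nvfp4_mlp_only",
]


def check_qformats(qformat: str) -> None:
    """Validate a comma-separated qformat string against the shared list."""
    qformat_list = qformat.split(",")
    assert qformat_list, "No quantization formats provided"
    assert all(q in SUPPORTED_HF_QFORMATS for q in qformat_list), (
        "One or more quantization formats provided are not supported for unified checkpoint export"
    )

Hoisting the list would also give the second check in main() (further down in this diff) a single source of truth. The hunk continues below with the new single assert.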
"fp8",
"int4_awq",
"nvfp4",
"nvfp4_awq",
"w4a8_awq",
"fp8_pb_wo",
"w4a8_mxfp4_fp8",
"nvfp4_mlp_only",
]
for qformat in qformat_list
), "One or more quantization formats provided are not supported for unified checkpoint export"

def loss_func(output, data):
# For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast`
@@ -219,27 +211,21 @@ def main(args):
"Quantization supports only one quantization format."
)

# Check arguments for unified_hf export format and set to default if unsupported arguments are provided
if args.export_fmt == "hf":
assert args.sparsity_fmt == "dense", (
f"Sparsity format {args.sparsity_fmt} not supported by unified export api."
)

if not args.auto_quantize_bits:
assert (
args.qformat
in [
"int4_awq",
"fp8",
"nvfp4",
"nvfp4_awq",
"w4a8_awq",
"fp8_pb_wo",
"w4a8_mxfp4_fp8",
"nvfp4_mlp_only",
]
or args.kv_cache_qformat in KV_QUANT_CFG_CHOICES
), f"Quantization format {args.qformat} not supported for HF export path"
if not args.auto_quantize_bits:
assert (
args.qformat
in [
"int4_awq",
"fp8",
"nvfp4",
"nvfp4_awq",
"w4a8_awq",
"fp8_pb_wo",
"w4a8_mxfp4_fp8",
"nvfp4_mlp_only",
]
or args.kv_cache_qformat in KV_QUANT_CFG_CHOICES
), f"Quantization format {args.qformat} not supported for HF export path"
Review comment on lines +214 to 229: This assertion is ineffective; it currently lets unsupported qformats through. Tighten the check:

- if not args.auto_quantize_bits:
-     assert (
-         args.qformat
-         in [
-             "int4_awq",
-             "fp8",
-             "nvfp4",
-             "nvfp4_awq",
-             "w4a8_awq",
-             "fp8_pb_wo",
-             "w4a8_mxfp4_fp8",
-             "nvfp4_mlp_only",
-         ]
-         or args.kv_cache_qformat in KV_QUANT_CFG_CHOICES
-     ), f"Quantization format {args.qformat} not supported for HF export path"
+ if not args.auto_quantize_bits:
+     assert args.qformat in ALLOWED_UNIFIED_HF_QFORMATS, (
+         f"Quantization format {args.qformat} not supported for HF export path"
+     )

Note: If you intended to allow "KV-cache-only" quant, handle that as a separate branch instead of weakening this assert.
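One way that separate branch could look, as a hedged sketch (the contents of ALLOWED_UNIFIED_HF_QFORMATS and KV_QUANT_CFG_CHOICES here, and the "none" sentinel for "no weight quantization", are assumptions rather than values taken from the actual script):

# Hedged sketch of the "separate branch" idea from the note above. The names
# mirror the suggestion, but the values and the "none" sentinel are assumptions.
ALLOWED_UNIFIED_HF_QFORMATS = [
    "int4_awq",
    "fp8",
    "nvfp4",
    "nvfp4_awq",
    "w4a8_awq",
    "fp8_pb_wo",
    "w4a8_mxfp4_fp8",
    "nvfp4_mlp_only",
]
KV_QUANT_CFG_CHOICES = {"fp8": "FP8_KV_CFG", "nvfp4": "NVFP4_KV_CFG"}  # placeholder values


def validate_hf_export(qformat: str, kv_cache_qformat: str, auto_quantize_bits=None) -> None:
    if auto_quantize_bits:
        return  # the auto-quantize path validates its formats separately
    if qformat in ALLOWED_UNIFIED_HF_QFORMATS:
        return  # supported weight-quantization format for unified HF export
    if qformat == "none" and kv_cache_qformat in KV_QUANT_CFG_CHOICES:
        return  # KV-cache-only quantization: no weight format requested
    raise ValueError(f"Quantization format {qformat} not supported for HF export path")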

# If low memory mode is enabled, we compress the model while loading the HF checkpoint.
calibration_only = False

@@ -253,9 +239,6 @@ def main(args):
attn_implementation=args.attn_implementation,
)
else:
assert args.export_fmt == "hf", (
"Low memory mode is only supported for exporting HF checkpoint."
)
assert args.qformat in QUANT_CFG_CHOICES, (
f"Quantization format is not supported for low memory mode. Supported formats: {QUANT_CFG_CHOICES.keys()}"
)
@@ -600,34 +583,31 @@ def output_decode(generated_ids, input_shape):
setattr(model.config, "architectures", full_model_config.architectures)

start_time = time.time()
if args.export_fmt == "tensorrt_llm":
if model_type in ["t5", "bart", "whisper"] or args.sparsity_fmt != "dense":
warnings.warn(
"Still exporting TensorRT-LLM checkpoints for models not supported by the TensorRT-LLM torch runtime."
)

# Move meta tensor back to device before exporting.
remove_hook_from_module(model, recurse=True)

dtype = None
if "w4a8_awq" in args.qformat:
# TensorRT-LLM w4a8 only support fp16 as the dtype.
dtype = torch.float16

# For Gemma2-27B, TRT-LLM only works with bfloat16 as the dtype.
if model_type == "gemma2":
dtype = torch.bfloat16

export_tensorrt_llm_checkpoint(
model,
model_type,
dtype=dtype,
export_dir=export_path,
inference_tensor_parallel=args.inference_tensor_parallel,
inference_pipeline_parallel=args.inference_pipeline_parallel,
)
elif args.export_fmt == "hf":
else:
# Check arguments for unified_hf export format and set to default if unsupported arguments are provided
assert args.sparsity_fmt == "dense", (
f"Sparsity format {args.sparsity_fmt} not supported by unified export api."
)

export_hf_checkpoint(
full_model,
export_dir=export_path,
)
else:
raise NotImplementedError(f"{args.export_fmt} not supported")

# Restore default padding and export the tokenizer as well.
if tokenizer is not None:
@@ -710,9 +690,9 @@ def output_decode(generated_ids, input_shape):
parser.add_argument(
"--export_fmt",
required=False,
default="tensorrt_llm",
default="hf",
choices=["tensorrt_llm", "hf"],
help=("Checkpoint export format"),
help="Deprecated. Please avoid using this argument.",
)
parser.add_argument(
"--trust_remote_code",
@@ -767,6 +747,9 @@ def output_decode(generated_ids, input_shape):

args = parser.parse_args()

if args.export_fmt != "hf":
warnings.warn("Deprecated. --export_fmt will be ignored.")

args.dataset = args.dataset.split(",") if args.dataset else None
args.calib_size = [int(num_sample) for num_sample in args.calib_size.split(",")]
main(args)
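A possible refinement to that deprecation notice, shown as a self-contained sketch (the warning category and stacklevel are suggestions, not part of this PR; the argparse setup mirrors the diff above):

import argparse
import warnings

parser = argparse.ArgumentParser()
parser.add_argument(
    "--export_fmt",
    required=False,
    default="hf",
    choices=["tensorrt_llm", "hf"],
    help="Deprecated. Please avoid using this argument.",
)
args = parser.parse_args([])  # empty argv, for illustration only

if args.export_fmt != "hf":
    # An explicit DeprecationWarning category and stacklevel make the notice
    # filterable and point it at the caller rather than this line.
    warnings.warn(
        "--export_fmt is deprecated and will be ignored.",
        DeprecationWarning,
        stacklevel=2,
    )

With a DeprecationWarning category, downstream users can silence or surface the notice through standard -W warning filters without editing the script.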