20 changes: 19 additions & 1 deletion modelopt/onnx/quantization/__main__.py
@@ -41,7 +41,7 @@ def get_parser() -> argparse.ArgumentParser:
argparser.add_argument(
"--calibration_method",
type=str,
choices=["max", "entropy", "awq_clip", "rtn_dq"],
choices=["max", "entropy", "awq_clip", "rtn_dq", "awq_full", "awq_lite", "rtn"],
help=(
"Calibration method choices for int8/fp8: {entropy (default), max}, "
"int4: {awq_clip (default), rtn_dq}."
@@ -255,6 +255,22 @@ def get_parser() -> argparse.ArgumentParser:
"The currently supported precisions are {fp16, int8, fp8}."
),
)
argparser.add_argument(
"--kv_quant_mode",
type=str,
choices=["NONE", "PER_TENSOR", "PER_CHANNEL"],
default="NONE",
help=(
"Quantization mode for kv cache in GQA. NONE (default) means no quantization for kv cache, "
),
)
argparser.add_argument(
"--kv_cache_type",
type=str,
choices=["fp8", "int8"],
default="NONE",
help=("Quantization type for kv cache in GQA. fp8 is default."),
)
return argparser


@@ -298,6 +314,8 @@ def main():
simplify=args.simplify,
calibrate_per_node=args.calibrate_per_node,
direct_io_types=args.direct_io_types,
kv_quant_mode=args.kv_quant_mode,
kv_cache_type=args.kv_cache_type,
)


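For reference, a minimal sketch of driving the new flags from Python is below. The `kv_quant_mode` and `kv_cache_type` keywords come from this diff; `onnx_path`, `quantize_mode`, and the top-level `quantize` import are assumptions about the existing API, not shown here.

```python
# Hedged sketch only: the kv_* keywords come from this PR; onnx_path and
# quantize_mode are assumed names for the existing top-level API.
from modelopt.onnx.quantization import quantize

quantize(
    onnx_path="model.onnx",
    quantize_mode="int4",
    calibration_method="awq_lite",
    kv_quant_mode="PER_TENSOR",  # new: per-tensor KV-cache calibration
    kv_cache_type="fp8",         # new: target KV-cache precision
)
```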
3 changes: 3 additions & 0 deletions modelopt/onnx/quantization/fp8.py
@@ -181,6 +181,7 @@ def quantize(
calibrate_per_node: bool = False,
custom_ops_to_quantize: list[str] = [],
direct_io_types: bool = False,
kv_quant_mode: str = "NONE",
**kwargs,
) -> onnx.ModelProto:
"""Applies FP8 GEMM only quantization to an ONNX file.
@@ -295,6 +296,8 @@ def quantize(
# With ActivationSymmetric as True, MinMax calibration is equivalent to max calibration
else CalibrationMethod.MinMax
),
intermediate_generated_files=intermediate_generated_files,
Contributor:

I didn't get how KV-cache quantization metadata is used with int8/fp8 quantization. Can you please elaborate on the flow?

Author:

Both int8 and fp8 quantization call quantize_static from ort_patching. The change happens in ort_patching.py. If kv_quant_mode is not "NONE", it will save additional calibration data on disk.
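A minimal sketch of that flow, assuming a helper along these lines (the name, file format, and hook point are hypothetical; the real logic lives in ort_patching.py / kv_cache.py, which are not part of this diff):

```python
# Hypothetical sketch of the described behavior; not the actual ort_patching code.
import numpy as np

def maybe_save_kv_calib_data(
    kv_quant_mode: str,
    kv_tensors: dict[str, np.ndarray],
    intermediate_generated_files: list[str],
    out_path: str = "kv_cache_calib.npz",
) -> None:
    """Persist KV activations gathered during the calibration sweep."""
    if kv_quant_mode == "NONE":
        return  # KV-cache quantization disabled: nothing extra is written
    np.savez(out_path, **kv_tensors)
    # Record the file so the caller can clean it up later, mirroring how this PR
    # threads intermediate_generated_files through the int8/fp8/int4 paths.
    intermediate_generated_files.append(out_path)
```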

kv_quant_mode=kv_quant_mode,
)
intermediate_generated_files.append(tmp_onnx_path)
if use_external_data_format:
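For readers unfamiliar with the two KV quantization modes, the sketch below illustrates the usual distinction between per-tensor and per-channel max calibration. It is a generic illustration, not modelopt's actual KV-scale computation.

```python
# Illustration only: conventional per-tensor vs per-channel max-calibration scales
# for a KV activation of shape (batch, heads, seq_len, head_dim).
import numpy as np

def kv_scales(kv: np.ndarray, mode: str, qmax: float = 448.0) -> np.ndarray:
    """qmax=448 is the largest FP8 E4M3 value; use 127 for int8."""
    if mode == "PER_TENSOR":
        amax = np.abs(kv).max()              # single scale for the whole tensor
    elif mode == "PER_CHANNEL":
        amax = np.abs(kv).max(axis=(0, 2))   # one scale per (head, head_dim) channel
    else:
        raise ValueError(f"unexpected kv_quant_mode: {mode}")
    return np.asarray(amax) / qmax
```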
43 changes: 42 additions & 1 deletion modelopt/onnx/quantization/int4.py
@@ -45,6 +45,10 @@
get_tensor_producer_nodes,
)
from modelopt.onnx.quantization.gs_patching import patch_gs_modules
from modelopt.onnx.quantization.kv_cache import (
save_kv_cache_calib_data,
save_kv_cache_calib_data_rtn,
)
from modelopt.onnx.quantization.ort_utils import create_inference_session
from modelopt.onnx.quantization.quant_utils import (
_pad,
@@ -444,6 +448,8 @@ def _quantize_awq_clip(
force_fp16: bool = False,
nodes_to_exclude: list[str] = [],
input_shapes_profile: Sequence[dict[str, str]] | None = None,
intermediate_generated_files: list[str] = [],
kv_quant_mode: str = "NONE",
**kwargs: Any,
) -> onnx.ModelProto:
"""Quantizes `onnx_model` using the Activation aware quantization a.k.a AWQ algorithm."""
@@ -482,10 +488,19 @@ def _quantize_awq_clip(
# Apply AWQ clip on selected weights
t = time.time()
alphas = {}

if kv_quant_mode != "NONE":
save_kv_cache_calib_data(
onnx_model,
session=session,
inputs=inputs,
intermediate_generated_files=intermediate_generated_files,
)

for i in tqdm(range(len(wa_pack)), desc="Running clip search..."):
act_tensor, weight_tensor, do_transpose, gemm_io_type = wa_pack[i]

# First capture all the activation values after calibration data sweep
# First capture all the activation values after calibration data sweep
output_dicts = {}
for inp_d in inputs:
np_inp_d = {name: numpy.asarray(tensor) for name, tensor in inp_d.items()}
@@ -968,6 +983,8 @@ def _quantize_awq_lite(
use_zero_point: bool = False,
nodes_to_exclude: list[str] = [],
input_shapes_profile: Sequence[dict[str, str]] | None = None,
intermediate_generated_files: list[str] = [],
kv_quant_mode: str = "NONE",
**kwargs: Any,
) -> onnx.ModelProto:
"""Quantizes `onnx_model` using the Activation aware quantization a.k.a AWQ algorithm."""
@@ -1025,6 +1042,14 @@

gc.collect()

if kv_quant_mode != "NONE":
save_kv_cache_calib_data(
onnx_model,
session=session,
inputs=inputs,
intermediate_generated_files=intermediate_generated_files,
)

output_data = []

if enable_fast_path_using_high_sysram:
@@ -1328,6 +1353,8 @@ def quantize(
nodes_to_exclude: list[str] | None = [r"/lm_head"],
log_level: str = "INFO",
input_shapes_profile: Sequence[dict[str, str]] | None = None,
intermediate_generated_files: list[str] = [],
kv_quant_mode: str = "NONE",
**kwargs: Any,
) -> onnx.ModelProto:
"""Applies INT4 Weight-Only-Quantization (WoQ) to an ONNX model.
Expand Down Expand Up @@ -1421,6 +1448,16 @@ def quantize(
qdq.use_trt_qdq_ops()

if calibration_method in ["rtn", "rtn_dq", "rtn_trt", "rtn_trt_dq"]:
# Save kv-cache calibration data if kv_quant_mode is not NONE
if kv_quant_mode != "NONE":
save_kv_cache_calib_data_rtn(
Contributor:

For INT4 AWQ/RTN + 8-bit KV-cache quantization, can we avoid two session runs by preparing the KV tensor names before creating the augmented model, augmenting the model for these KV tensors too, and post-processing for save-kv-cache-calib-data after the AWQ/RTN loop?

Just checking if we can avoid two session runs and thereby speed up the combined quantization of MatMul and the KV cache.

Author:

Yes, that's possible, but this change won't apply to the int8/fp8 path, and awq_lite, awq_clip, and rtn would each need separate implementations, which means not much code can be reused. If you feel it's worth it, I can definitely implement it this way.
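A minimal sketch of that single-pass idea, with a hypothetical helper name (not part of this PR): expose the KV tensors as extra graph outputs before the calibration sweep so one session run serves both the AWQ/RTN search and the KV-cache statistics.

```python
# Hypothetical sketch of the single-session-run idea discussed above.
import onnx

def add_kv_outputs(model: onnx.ModelProto, kv_tensor_names: list[str]) -> onnx.ModelProto:
    """Augment the graph so KV tensors are also emitted as model outputs."""
    existing = {o.name for o in model.graph.output}
    for name in kv_tensor_names:
        if name not in existing:
            # An empty value_info is enough for ORT to surface the tensor as an output.
            model.graph.output.append(onnx.helper.make_empty_tensor_value_info(name))
    return model

# After the AWQ/RTN calibration loop, the activations captured for these names
# could then be turned into KV-cache calibration data without a second session run.
```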

onnx_model,
intermediate_generated_files=intermediate_generated_files,
data_reader=calibration_data_reader,
calibration_eps=calibration_eps,
input_shapes_profile=input_shapes_profile,
use_external_data_format=use_external_data_format,
)
onnx_model = quantize_rtn(
onnx_model,
block_size,
@@ -1445,6 +1482,8 @@
use_zero_point=use_zero_point,
enable_weight_clipping=do_weight_clipping,
input_shapes_profile=input_shapes_profile,
kv_quant_mode=kv_quant_mode,
intermediate_generated_files=intermediate_generated_files,
**kwargs,
)
elif calibration_method in ["awq_clip", "awq_clip_trt"]:
@@ -1456,6 +1495,8 @@
block_size,
nodes_to_exclude=nodes_to_exclude,
input_shapes_profile=input_shapes_profile,
kv_quant_mode=kv_quant_mode,
intermediate_generated_files=intermediate_generated_files,
**kwargs,
)
else:
3 changes: 3 additions & 0 deletions modelopt/onnx/quantization/int8.py
@@ -131,6 +131,7 @@ def quantize(
calibrate_per_node: bool = False,
custom_ops_to_quantize: list[str] = [],
direct_io_types: bool = False,
kv_quant_mode: str = "NONE",
**kwargs,
) -> onnx.ModelProto:
"""Applies INT8 quantization to an ONNX file using the compiler friendly heuristics.
@@ -257,6 +258,8 @@ def quantize(
# With ActivationSymmetric as True, MinMax calibration is equivalent to max calibration
else CalibrationMethod.MinMax
),
intermediate_generated_files=intermediate_generated_files,
kv_quant_mode=kv_quant_mode,
)

intermediate_generated_files.append(tmp_onnx_path)