-
Notifications
You must be signed in to change notification settings - Fork 192
[Draft] [5526696] Add kv cache quantization support for onnx quantization #486
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -181,6 +181,7 @@ def quantize( | |
| calibrate_per_node: bool = False, | ||
| custom_ops_to_quantize: list[str] = [], | ||
| direct_io_types: bool = False, | ||
| kv_quant_mode: str = "NONE", | ||
| **kwargs, | ||
| ) -> onnx.ModelProto: | ||
| """Applies FP8 GEMM only quantization to an ONNX file. | ||
|
|
@@ -295,6 +296,8 @@ def quantize( | |
| # With ActivationSymmetric as True, MinMax calibration is equivalent to max calibration | ||
| else CalibrationMethod.MinMax | ||
| ), | ||
| intermediate_generated_files=intermediate_generated_files, | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I didn't get how KV-Cache quantization meta-data is used with int8/fp8 quantization. Can you please elaborate the flow? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Both int8/fp8 quantization call |
||
| kv_quant_mode=kv_quant_mode, | ||
| ) | ||
| intermediate_generated_files.append(tmp_onnx_path) | ||
| if use_external_data_format: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -45,6 +45,10 @@ | |
| get_tensor_producer_nodes, | ||
| ) | ||
| from modelopt.onnx.quantization.gs_patching import patch_gs_modules | ||
| from modelopt.onnx.quantization.kv_cache import ( | ||
| save_kv_cache_calib_data, | ||
| save_kv_cache_calib_data_rtn, | ||
| ) | ||
| from modelopt.onnx.quantization.ort_utils import create_inference_session | ||
| from modelopt.onnx.quantization.quant_utils import ( | ||
| _pad, | ||
|
|
@@ -444,6 +448,8 @@ def _quantize_awq_clip( | |
| force_fp16: bool = False, | ||
| nodes_to_exclude: list[str] = [], | ||
| input_shapes_profile: Sequence[dict[str, str]] | None = None, | ||
| intermediate_generated_files: list[str] = [], | ||
| kv_quant_mode: str = "NONE", | ||
| **kwargs: Any, | ||
| ) -> onnx.ModelProto: | ||
| """Quantizes `onnx_model` using the Activation aware quantization a.k.a AWQ algorithm.""" | ||
|
|
@@ -482,10 +488,19 @@ def _quantize_awq_clip( | |
| # Apply AWQ clip on selected weights | ||
| t = time.time() | ||
| alphas = {} | ||
|
|
||
| if kv_quant_mode != "NONE": | ||
| save_kv_cache_calib_data( | ||
| onnx_model, | ||
| session=session, | ||
| inputs=inputs, | ||
| intermediate_generated_files=intermediate_generated_files, | ||
| ) | ||
|
|
||
| for i in tqdm(range(len(wa_pack)), desc="Running clip search..."): | ||
| act_tensor, weight_tensor, do_transpose, gemm_io_type = wa_pack[i] | ||
|
|
||
| # First capture all the activation values after calibration data sweep | ||
| # First capture all the activation values after calibration data sweep | ||
| output_dicts = {} | ||
| for inp_d in inputs: | ||
| np_inp_d = {name: numpy.asarray(tensor) for name, tensor in inp_d.items()} | ||
|
|
@@ -968,6 +983,8 @@ def _quantize_awq_lite( | |
| use_zero_point: bool = False, | ||
| nodes_to_exclude: list[str] = [], | ||
| input_shapes_profile: Sequence[dict[str, str]] | None = None, | ||
| intermediate_generated_files: list[str] = [], | ||
| kv_quant_mode: str = "NONE", | ||
| **kwargs: Any, | ||
| ) -> onnx.ModelProto: | ||
| """Quantizes `onnx_model` using the Activation aware quantization a.k.a AWQ algorithm.""" | ||
|
|
@@ -1025,6 +1042,14 @@ def _quantize_awq_lite( | |
|
|
||
| gc.collect() | ||
|
|
||
| if kv_quant_mode != "NONE": | ||
| save_kv_cache_calib_data( | ||
| onnx_model, | ||
| session=session, | ||
| inputs=inputs, | ||
| intermediate_generated_files=intermediate_generated_files, | ||
| ) | ||
|
|
||
| output_data = [] | ||
|
|
||
| if enable_fast_path_using_high_sysram: | ||
|
|
@@ -1328,6 +1353,8 @@ def quantize( | |
| nodes_to_exclude: list[str] | None = [r"/lm_head"], | ||
| log_level: str = "INFO", | ||
| input_shapes_profile: Sequence[dict[str, str]] | None = None, | ||
| intermediate_generated_files: list[str] = [], | ||
| kv_quant_mode: str = "NONE", | ||
| **kwargs: Any, | ||
| ) -> onnx.ModelProto: | ||
| """Applies INT4 Weight-Only-Quantization (WoQ) to an ONNX model. | ||
|
|
@@ -1421,6 +1448,16 @@ def quantize( | |
| qdq.use_trt_qdq_ops() | ||
|
|
||
| if calibration_method in ["rtn", "rtn_dq", "rtn_trt", "rtn_trt_dq"]: | ||
| # Save kv-cache calibration data if kv_quant_mode is not NONE | ||
| if kv_quant_mode != "NONE": | ||
| save_kv_cache_calib_data_rtn( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For INT4 AWQ/RTN + 8-bit KV Cache quantization, can we avoid 2 session runs by preparing KV tensor names before creating augmented model, augmenting model for these KV tensors too, and post-processing for save-kv-cache-calib-data after AWQ/RTN loop? Just checking if we can avoid 2 session runs, and thereby speedup the combined quantization of matmul and kv-cache. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, that's possible, but this change won't apply to int8/fp8 path and for awq_lite, awq_clip, rtn, they need to be implemented separately which means not much code can be reused. If you fell it's worthy, I can definitely implement in this way. |
||
| onnx_model, | ||
| intermediate_generated_files=intermediate_generated_files, | ||
| data_reader=calibration_data_reader, | ||
| calibration_eps=calibration_eps, | ||
| input_shapes_profile=input_shapes_profile, | ||
| use_external_data_format=use_external_data_format, | ||
| ) | ||
| onnx_model = quantize_rtn( | ||
| onnx_model, | ||
| block_size, | ||
|
|
@@ -1445,6 +1482,8 @@ def quantize( | |
| use_zero_point=use_zero_point, | ||
| enable_weight_clipping=do_weight_clipping, | ||
| input_shapes_profile=input_shapes_profile, | ||
| kv_quant_mode=kv_quant_mode, | ||
| intermediate_generated_files=intermediate_generated_files, | ||
| **kwargs, | ||
| ) | ||
| elif calibration_method in ["awq_clip", "awq_clip_trt"]: | ||
|
|
@@ -1456,6 +1495,8 @@ def quantize( | |
| block_size, | ||
| nodes_to_exclude=nodes_to_exclude, | ||
| input_shapes_profile=input_shapes_profile, | ||
| kv_quant_mode=kv_quant_mode, | ||
| intermediate_generated_files=intermediate_generated_files, | ||
| **kwargs, | ||
| ) | ||
| else: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.