[Draft] [5526696] Add kv cache quantization support for onnx quantization #486
base: main
Conversation
Signed-off-by: zhanghaoc <[email protected]>
Codecov Report
❌ Patch coverage report. Additional details and impacted files:

@@            Coverage Diff             @@
##              main     #486      +/-   ##
==========================================
- Coverage    73.38%   73.02%    -0.36%
==========================================
  Files          180      181        +1
  Lines        17934    18260      +326
==========================================
+ Hits         13160    13334      +174
- Misses        4774     4926      +152
Signed-off-by: zhanghaoc <[email protected]>
# serialize the collected kv cache calibration data to disk
with open(calib_data_path, "wb") as f:
    pickle.dump(kv_tensor_data, f)
intermediate_generated_files.append(calib_data_path)
What is the memory impact (or other issues) if we keep the KV cache related data in a variable instead of writing it to a disk file?
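For reference, the footprint of that in-memory option can be estimated up front; a minimal sketch, assuming kv_tensor_data maps KV tensor names to lists of numpy arrays (the helper name is hypothetical, not from this PR):

def estimate_kv_calib_memory_mb(kv_tensor_data):
    # Sum the raw byte size of every collected KV calibration tensor.
    total_bytes = sum(arr.nbytes for arrays in kv_tensor_data.values() for arr in arrays)
    return total_bytes / (1024 * 1024)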
| f"Unsupported kv_cache_type {kv_cache_type} for kv cache quantization" | ||
| ) | ||
|
|
||
| kv_tensor_names_list.sort() |
Should we add an assert/exception/suitable safe early return here if the input model is not GenAI-based, i.e. it doesn't have the expected IO bindings/names? (e.g. if this list is empty, or if no GQA nodes are seen, etc.)
I think we currently support 8-bit KV Cache only with GenAI Builder exported ONNX LLMs, right?
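For illustration only, a minimal sketch of such an early check; the function name, error message, and the "present"/GroupQueryAttention heuristics are assumptions, not what the PR currently implements:

def validate_genai_kv_model(onnx_model):
    # GenAI Builder exported LLMs expose KV cache outputs named "present.*"
    # and use GroupQueryAttention nodes; reject anything else early.
    kv_outputs = sorted(o.name for o in onnx_model.graph.output if "present" in o.name)
    has_gqa = any(n.op_type == "GroupQueryAttention" for n in onnx_model.graph.node)
    if not kv_outputs or not has_gqa:
        raise ValueError(
            "KV cache quantization currently expects a GenAI Builder exported ONNX LLM "
            "with 'present.*' outputs and GroupQueryAttention nodes."
        )
    return kv_outputs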
if calibration_method in ["rtn", "rtn_dq", "rtn_trt", "rtn_trt_dq"]:
    # Save kv-cache calibration data if kv_quant_mode is not NONE
    if kv_quant_mode != "NONE":
        save_kv_cache_calib_data_rtn(
For INT4 AWQ/RTN + 8-bit KV Cache quantization, can we avoid two session runs by preparing the KV tensor names before creating the augmented model, augmenting the model for these KV tensors too, and post-processing for save-kv-cache-calib-data after the AWQ/RTN loop?
Just checking if we can avoid two session runs and thereby speed up the combined quantization of matmul and kv-cache.
Yes, that's possible, but this change won't apply to the int8/fp8 path, and awq_lite, awq_clip, and rtn would each need separate implementations, which means not much code can be reused. If you feel it's worth it, I can definitely implement it this way.
for output in onnx_model.graph.output:
    if "present" in output.name:
        kv_tensor_names_list.append(output.name)
if kv_cache_type == "fp8":
I think we can simplify this a bit by creating a map, with an assert/ValueError for unsupported types. Something like below:
output.type.tensor_type.elem_type = output_type_map[kv_cache_type]
where output_type_map = {"int8": , "fp8": }
Possibly we can also create a util for validating the dtype and the input model, i.e. whether they are currently supported or not.
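For illustration, a minimal sketch of that map-plus-validation util; the names KV_CACHE_ELEM_TYPE and set_kv_output_type, and the choice of FLOAT8E4M3FN for fp8, are assumptions rather than what the PR defines:

from onnx import TensorProto

KV_CACHE_ELEM_TYPE = {
    "int8": TensorProto.INT8,
    "fp8": TensorProto.FLOAT8E4M3FN,  # assumed fp8 flavor
}

def set_kv_output_type(output, kv_cache_type):
    # Validate the requested KV cache dtype, then rewrite the graph output's element type.
    if kv_cache_type not in KV_CACHE_ELEM_TYPE:
        raise ValueError(f"Unsupported kv_cache_type {kv_cache_type} for kv cache quantization")
    output.type.tensor_type.elem_type = KV_CACHE_ELEM_TYPE[kv_cache_type]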
    # With ActivationSymmetric as True, MinMax calibration is equivalent to max calibration
    else CalibrationMethod.MinMax
),
intermediate_generated_files=intermediate_generated_files,
I didn't get how the KV-Cache quantization metadata is used with int8/fp8 quantization. Can you please elaborate on the flow?
Both the int8 and fp8 quantization paths call quantize_static from ort_patching. The change happens in ort_patching.py: if kv_quant_mode is not "NONE", it saves additional calibration data to disk.
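To make that flow concrete, a rough sketch of the conditional path inside the patched quantize_static; the wrapper function name is hypothetical, and only kv_quant_mode, kv_tensor_data, calib_data_path, and intermediate_generated_files are taken from the snippets quoted earlier:

import pickle

def maybe_save_kv_calib_data(kv_quant_mode, kv_tensor_data, calib_data_path, intermediate_generated_files):
    # When KV cache quantization is requested, persist the extra calibration data
    # gathered during quantize_static so the later graph rewrite can attach k/v scales.
    if kv_quant_mode != "NONE":
        with open(calib_data_path, "wb") as f:
            pickle.dump(kv_tensor_data, f)
        intermediate_generated_files.append(calib_data_path)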
| node.input.append("") | ||
| node.input.append(k_scale.name) | ||
| node.input.append(v_scale.name) | ||
|
|
I think if kv-quant-type is per-channel then things won't go well, since we don't support it but also don't check for or flag it. Is that right?
Will add the check.
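For example, a simple guard along these lines; kv_quant_granularity is a hypothetical option name, not one defined in this PR:

def check_kv_quant_granularity(kv_quant_granularity):
    # Reject per-channel KV cache quantization explicitly instead of failing later.
    if kv_quant_granularity == "per_channel":
        raise NotImplementedError(
            "Per-channel KV cache quantization is not supported; only per-tensor "
            "k_scale / v_scale inputs are attached to the attention nodes."
        )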
What does this PR do?
Adds KV cache quantization. Currently supports int8/fp8 with the MinMax calibration method.
Testing
Testing is not done yet; still waiting for feedback.