
Conversation

@zhanghaoc

What does this PR do?

Add KV cache quantization. Currently supports int8/fp8 with the MinMax calibration method.

Overview:

  • Add a new file, kv_cache.py. It includes a function to save calibration data, a function to read that data back and compute the scales, and logic to add the new attributes and inputs to the ONNX model (a minimal scale-computation sketch follows this list).
  • Changes in the other files only pass the new parameters through.
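
For context, here is a minimal sketch of how a per-tensor MinMax scale for an 8-bit KV cache can be derived from calibrated min/max values. The function name and the quantization ranges shown are illustrative assumptions, not the PR's actual code.

def compute_kv_scale(calib_min: float, calib_max: float, kv_cache_type: str) -> float:
    """Return a per-tensor scale from calibrated min/max values (MinMax calibration)."""
    amax = max(abs(calib_min), abs(calib_max))
    if kv_cache_type == "int8":
        qmax = 127.0  # symmetric int8 range
    elif kv_cache_type == "fp8":
        qmax = 448.0  # max representable magnitude of FP8 E4M3
    else:
        raise ValueError(f"Unsupported kv_cache_type {kv_cache_type}")
    return amax / qmax

For example, compute_kv_scale(-3.2, 2.7, "int8") returns roughly 0.0252.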

Usage

python -m modelopt.onnx.quantization --onnx_path="C:\repos\models\Llama-3.2-3B-Instruct-ONNX\cuda\cuda-fp16\model.onnx" --quantize_mode=int4 --calibration_method=rtn_dq --kv_quant_mode=PER_TENSOR --output_path="C:\repos\models\Llama-3.2-3B-Instruct-ONNX\cuda\cuda-fp16\model.int4.rtn_dq.kv_cache.onnx" --log_level=DEBUG

Testing

Testing is not done yet; still waiting for feedback.

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: No
  • Did you add or update any necessary documentation?: Yes
  • Did you update Changelog?: No

Additional Information

@zhanghaoc zhanghaoc requested a review from a team as a code owner October 31, 2025 00:16
@zhanghaoc zhanghaoc requested a review from gcunhase October 31, 2025 00:16
@copy-pr-bot

copy-pr-bot bot commented Oct 31, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@codecov

codecov bot commented Oct 31, 2025

Codecov Report

❌ Patch coverage is 26.27737% with 101 lines in your changes missing coverage. Please review.
✅ Project coverage is 73.02%. Comparing base (9e64f81) to head (a877d02).
⚠️ Report is 38 commits behind head on main.

Files with missing lines                      Patch %   Missing lines
modelopt/onnx/quantization/kv_cache.py        15.17%    95 ⚠️
modelopt/onnx/quantization/int4.py            88.88%    2 ⚠️
modelopt/onnx/quantization/ort_patching.py    33.33%    2 ⚠️
modelopt/onnx/quantization/quantize.py        50.00%    2 ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #486      +/-   ##
==========================================
- Coverage   73.38%   73.02%   -0.36%     
==========================================
  Files         180      181       +1     
  Lines       17934    18260     +326     
==========================================
+ Hits        13160    13334     +174     
- Misses       4774     4926     +152     

Signed-off-by: zhanghaoc <[email protected]>
@vishalpandya1990 vishalpandya1990 removed their assignment Oct 31, 2025
# Serialize the collected KV calibration tensors to disk (pickle, not JSON)
with open(calib_data_path, "wb") as f:
    pickle.dump(kv_tensor_data, f)
intermediate_generated_files.append(calib_data_path)
Contributor


What is the memory impact (or other issues) if we keep the KV-cache-related data in a variable instead of writing it to a disk file?

f"Unsupported kv_cache_type {kv_cache_type} for kv cache quantization"
)

kv_tensor_names_list.sort()
Contributor

@vishalpandya1990 vishalpandya1990 Nov 5, 2025


Should we add an assert/exception/suitable safe early return here if the input model is not GenAI-based, i.e. it doesn't have the expected IO bindings/names (e.g. if this list is empty, or if no GQA nodes are seen, etc.)?

I think we currently support 8-bit KV cache only with GenAI-Builder-exported ONNX LLMs, right?
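
One possible shape of such a guard, as a sketch only (the function name, error message, and placement are assumptions, not the PR's code):

import onnx

def validate_kv_model(onnx_model: onnx.ModelProto) -> list[str]:
    """Fail early if the model has no 'present.*' KV cache outputs."""
    kv_tensor_names = [o.name for o in onnx_model.graph.output if "present" in o.name]
    if not kv_tensor_names:
        raise ValueError(
            "KV cache quantization requested, but no 'present.*' graph outputs were found; "
            "8-bit KV cache currently expects a GenAI-Builder-exported ONNX LLM."
        )
    return sorted(kv_tensor_names)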

if calibration_method in ["rtn", "rtn_dq", "rtn_trt", "rtn_trt_dq"]:
    # Save kv-cache calibration data if kv_quant_mode is not NONE
    if kv_quant_mode != "NONE":
        save_kv_cache_calib_data_rtn(
Contributor


For INT4 AWQ/RTN + 8-bit KV cache quantization, can we avoid two session runs by preparing the KV tensor names before creating the augmented model, augmenting the model for these KV tensors too, and post-processing the saved KV-cache calibration data after the AWQ/RTN loop?

Just checking if we can avoid two session runs and thereby speed up the combined quantization of MatMul and KV cache.

Author


Yes, that's possible, but this change wouldn't apply to the int8/fp8 path, and awq_lite, awq_clip, and rtn would each need separate implementations, which means not much code could be reused. If you feel it's worth it, I can definitely implement it this way.

for output in onnx_model.graph.output:
    if "present" in output.name:
        kv_tensor_names_list.append(output.name)
        if kv_cache_type == "fp8":
Contributor

@vishalpandya1990 vishalpandya1990 Nov 5, 2025


I think we can simplify this a bit by creating a map, with an assert/ValueError for unsupported types. Something like:

output.type.tensor_type.elem_type = output_type_map[kv_cache_type]

where output_type_map = {"int8": ..., "fp8": ...}.

Possibly we can also create a util for validating the dtype and input model, i.e. whether they are currently supported or not.
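
A sketch of that mapping, assuming the intended enums are onnx.TensorProto.INT8 and onnx.TensorProto.FLOAT8E4M3FN (not confirmed against the PR):

import onnx

KV_OUTPUT_TYPE_MAP = {
    "int8": onnx.TensorProto.INT8,
    "fp8": onnx.TensorProto.FLOAT8E4M3FN,
}

def set_kv_output_type(output: onnx.ValueInfoProto, kv_cache_type: str) -> None:
    """Set the element type of a 'present.*' output, rejecting unsupported types."""
    if kv_cache_type not in KV_OUTPUT_TYPE_MAP:
        raise ValueError(f"Unsupported kv_cache_type {kv_cache_type} for KV cache quantization")
    output.type.tensor_type.elem_type = KV_OUTPUT_TYPE_MAP[kv_cache_type]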

# With ActivationSymmetric as True, MinMax calibration is equivalent to max calibration
else CalibrationMethod.MinMax
),
intermediate_generated_files=intermediate_generated_files,
Contributor


I didn't get how the KV cache quantization metadata is used with int8/fp8 quantization. Can you please elaborate on the flow?

Author


Both int8 and fp8 quantization call quantize_static from ort_patching. The change happens in ort_patching.py: if kv_quant_mode is not NONE, it saves additional calibration data to disk.
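
A rough sketch of the read-back half of that flow (the data layout and names here are assumptions): kv_cache.py would load the pickled stats and turn them into per-tensor scales, e.g.:

import pickle

def load_kv_scales(calib_data_path: str, kv_cache_type: str) -> dict[str, float]:
    """Turn the dumped per-tensor min/max stats into per-tensor scales."""
    with open(calib_data_path, "rb") as f:
        kv_tensor_data = pickle.load(f)  # assumed layout: {tensor_name: (min, max)}
    qmax = 127.0 if kv_cache_type == "int8" else 448.0  # int8 vs FP8 E4M3 max
    return {
        name: max(abs(t_min), abs(t_max)) / qmax
        for name, (t_min, t_max) in kv_tensor_data.items()
    }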

node.input.append("")
node.input.append(k_scale.name)
node.input.append(v_scale.name)

Contributor


I think if the KV quant mode is per-channel then things won't go well, since we don't support it but also don't check for or flag it. Is that right?

Author


Added the check.
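
A minimal version of that check, as a sketch only (the accepted values are assumed from the --kv_quant_mode usage above):

def check_kv_quant_mode(kv_quant_mode: str) -> None:
    """Reject KV quant modes other than NONE and PER_TENSOR."""
    if kv_quant_mode not in ("NONE", "PER_TENSOR"):
        raise ValueError(
            f"kv_quant_mode={kv_quant_mode} is not supported; "
            "only PER_TENSOR KV cache quantization is currently implemented."
        )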
