
[WC][OV] NVFP4 support #3967

Open
daniil-lyakhov wants to merge 4 commits into openvinotoolkit:develop from daniil-lyakhov:dl/nvfp4_rev1

Conversation

@daniil-lyakhov
Collaborator

@daniil-lyakhov daniil-lyakhov commented Mar 2, 2026

The NVFP4 dtype is introduced:
f4e2m1 weight compression with a constant group size of 16
The scale is compressed to f8e4m3 using a single fp32 second-degree scale
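The two-level scaling scheme can be sketched in plain numpy. This is an illustrative mock, not the NNCF implementation: the function names are invented for the example, and the actual rounding of the group scales to f8e4m3 is skipped (a comment marks where it would happen).

```python
import numpy as np

E2M1_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])  # representable |values| of f4e2m1
E4M3_MAX = 448.0  # largest finite f8e4m3 value

def nvfp4_quantize(w, group_size=16):
    """Toy NVFP4 quantizer for a 1-D float32 weight (length divisible by group_size)."""
    groups = w.reshape(-1, group_size)
    # per-group scale maps each group's amax onto the E2M1 maximum (6.0)
    group_scale = np.abs(groups).max(axis=1, keepdims=True) / E2M1_GRID[-1]
    # single fp32 second-degree scale maps group scales into the f8e4m3 range
    second_degree_scale = group_scale.max() / E4M3_MAX
    scaled_group_scale = group_scale / second_degree_scale  # a real impl would round this to f8e4m3
    # snap each weight to the nearest representable E2M1 magnitude, keeping the sign
    ratio = groups / np.where(group_scale == 0, 1.0, group_scale)
    idx = np.abs(np.abs(ratio)[..., None] - E2M1_GRID).argmin(axis=-1)
    return np.sign(ratio) * E2M1_GRID[idx], scaled_group_scale, second_degree_scale

def nvfp4_dequantize(q, scaled_group_scale, second_degree_scale):
    # multiplying the two scales back together recovers the original per-group scale
    return q * scaled_group_scale * second_degree_scale
```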

Changes

  • CompressedWeight container is extended with second_degree_scale attribute
  • do_float_quantization/do_float_dequantization and do_integer_quantization/do_integer_dequantization are updated to return CompressedWeight instead of a list of tensors, simplifying the output of the functions (a container with named attributes is returned instead of 5 tensors)
  • OpenVINO WC backend is extended to insert NVFP4 compression subgraphs with 2 scales in it
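For reference, the extended container might look roughly like the following dataclass. The field set follows the attributes mentioned in this PR and its docstrings (zero_point, codebook, second_degree_scale), but the types and exact layout here are guessed, not copied from the NNCF source.

```python
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class CompressedWeight:
    # Hypothetical sketch: field names follow the PR description, not the NNCF source.
    tensor: Any                                 # compressed weight values (e.g. f4e2m1 or int)
    scale: Any                                  # decompression scale (e.g. f8e4m3 for NVFP4)
    zero_point: Optional[Any] = None            # applicable for INT quantization
    codebook: Optional[Any] = None              # LUT, applicable for vector quantization
    second_degree_scale: Optional[Any] = None   # fp32 scale used when the scale itself is compressed
```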

Reason for changes

  • To support NVFP4 compression

Related tickets

Tests

  • tests/torch/function_hook/quantization/test_weights_compression.py and tests/onnx/quantization/test_weights_compression.py are updated to check that the NVFP4 mode raises a not-supported-parameter error
  • In tests/openvino/native/quantization/test_weights_compression.py:
    ◦ test_compare_compressed_weights checks that the subgraph is correct and that the scales/compressed weights are calculated correctly
    ◦ test_float_compressed_weighs_range checks that do_float_quantization and do_float_dequantization are correct with NVFP4
    ◦ TestUnsupportedParams (+ test_nvfp4_precomputed_scales) checks that no algorithm / group_size != 16 / fallback mode / precomputed scales are supported with NVFP4
    ◦ test_mixed_precision_fp checks the correctness of the mixed precision algorithm with NVFP4 (and the correctness of the group_size=16 param)

@github-actions github-actions bot added documentation Improvements or additions to documentation NNCF PT Pull requests that updates NNCF PyTorch NNCF OpenVINO Pull requests that updates NNCF OpenVINO NNCF ONNX Pull requests that updates NNCF ONNX API Public API-impacting changes labels Mar 2, 2026
@daniil-lyakhov daniil-lyakhov marked this pull request as ready for review March 3, 2026 13:47
@daniil-lyakhov daniil-lyakhov requested a review from a team as a code owner March 3, 2026 13:47
Copilot AI review requested due to automatic review settings March 3, 2026 13:47
Contributor

Copilot AI left a comment


Pull request overview

This PR introduces NVFP4 (NVIDIA FP4) as a new weight compression dtype for OpenVINO backend only. NVFP4 uses f4e2m1 (E2M1) values with a fixed group size of 16, where each group has an f8e4m3 (E4M3) scale that is further quantized using a per-weight FP32 second-degree scale.

Changes:

  • CompressWeightsMode.NVFP4 is added as a new mode with group size 16, f4e2m1 compressed weights, and a two-level scale (f8e4m3 group scale + FP32 per-weight second-degree scale).
  • do_float_quantization/do_float_dequantization and do_integer_quantization/do_integer_dequantization are refactored to return/accept CompressedWeight instead of tuples of tensors, and CompressedWeight is extended with a second_degree_scale field.
  • The OpenVINO backend's _create_compression_subgraph inserts the NVFP4 two-scale dequantization subgraph.

Reviewed changes

Copilot reviewed 18 out of 18 changed files in this pull request and generated 5 comments.

Show a summary per file
File Description
src/nncf/parameters.py Adds NVFP4 = "nvfp4" to CompressWeightsMode enum
src/nncf/quantization/algorithms/weight_compression/parameters.py Adds second_degree_scale field to CompressedWeight; removes is_codebook() method
src/nncf/quantization/algorithms/weight_compression/config.py Maps NVFP4 to f4e2m1 compression dtype; marks it as non-integer
src/nncf/quantization/algorithms/weight_compression/weight_lowering.py Implements NVFP4 two-scale quantization/dequantization; refactors do_float/integer_quantization/dequantization return types
src/nncf/quantization/algorithms/weight_compression/algorithm.py Sets default group size 16 and validates NVFP4 group size constraint
src/nncf/quantization/algorithms/weight_compression/openvino_backend.py Inserts two-scale dequantization subgraph for NVFP4
src/nncf/quantization/algorithms/weight_compression/scale_estimation.py Updates call sites to use new CompressedWeight return types
src/nncf/quantization/algorithms/weight_compression/lora_correction.py Updates dequantization call sites
src/nncf/quantization/quantize_model.py Adds NVFP4 to unsupported mode lists for Torch/ONNX backends
src/nncf/openvino/optimized_functions/functions.py Updates optimized quantization functions to return CompressedWeight
tests/openvino/native/quantization/test_weights_compression.py Adds comprehensive NVFP4 tests: subgraph check, scale range check, unsupported params, mixed precision
tests/openvino/native/data/2026.0/reference_scales/IntegerModel_compressed_weights_nvfp4.json Reference data for NVFP4 compressed weight test
tests/openvino/optimized_functions/test_compression_functions.py Updates to use CompressedWeight object in assertions
tests/torch/function_hook/quantization/test_weights_compression.py Adds NVFP4 to unsupported modes for Torch
tests/onnx/quantization/test_weights_compression.py Adds NVFP4 to unsupported modes for ONNX
docs/usage/post_training_compression/weights_compression/Usage.md Documents NVFP4 format in the modes table
docs/Algorithms.md Mentions NVFP4 in supported types
.ci/cspell_dict.txt Adds "nvfp" to the spellcheck dictionary
Comments suppressed due to low confidence (3)

src/nncf/quantization/algorithms/weight_compression/weight_lowering.py:451

  • The docstring for do_integer_quantization still says :return: A tuple containing the compressed weights, scale, and zero point tensors., but the return type was changed from tuple[Tensor, Tensor, Tensor] to CompressedWeight. The return description should be updated to match the new return type.
) -> CompressedWeight:
    """
    Performs integer quantization on the given weight tensor.

    :param weight: The weight tensor to quantize.
    :param config: The weight compression configuration.
    :param reduction_axes: Axes along which to reduce (collect) statistics (e.g., min, max). Not required if
        precomputed scale (and zero point) are provided.
    :param precomputed_scale: Optional precomputed scale tensor.
    :param precomputed_zero_point: Optional precomputed zero point tensor.
    :return: A tuple containing the compressed weights, scale, and zero point tensors.

src/nncf/quantization/algorithms/weight_compression/weight_lowering.py:515

  • The docstring for integer_quantize_dequantize_weight still says :return: Dequantized weight tensor or a tuple containing the decompressed weight, compressed weight, scale, (and zero point)., but the return type was changed to Tensor | tuple[Tensor, CompressedWeight]. The description should be updated to reflect that when return_compressed_weight=True, only a (decompressed_weight, CompressedWeight) tuple is returned.
    """
    First quantizes the given weight tensor to integer dtype and then dequantizes it back to obtain float32 values.

    :param weight: The weight tensor to quantize-dequantize.
    :param config: Compression configuration.
    :param reduction_axes: Axes along which to reduce (collect) statistics (e.g., min, max). Not required if
        precomputed scale (and zero point) are provided.
    :param precomputed_scale: Optional precomputed scale tensor.
    :param precomputed_zero_point: Optional precomputed zero point tensor.
    :param return_compressed_weight: If True, besides decompressed weight will also return compressed weight, scale,
        (and zero point).
    :return: Dequantized weight tensor or a tuple containing the decompressed weight, compressed weight, scale,
        (and zero point).

src/nncf/quantization/algorithms/weight_compression/weight_lowering.py:250

  • The docstring for float_quantize_dequantize_weight still says :return: Dequantized weight tensor or a tuple containing the decompressed weight, compressed weight and scale. but the return type was changed to Tensor | tuple[Tensor, CompressedWeight]. The description should be updated to reflect that the tuple now contains (decompressed_weight, CompressedWeight) rather than (decompressed_weight, compressed_weight, scale).
) -> Tensor | tuple[Tensor, CompressedWeight]:
    """
    First quantizes the given weight tensor to float dtype and then dequantizes it back to obtain float32 values.

    :param weight: The weight tensor to quantize-dequantize.
    :param config: Compression configuration.
    :param reduction_axes: Axes along which to reduce statistics. Not required if precomputed scale are provided.
    :param precomputed_scale: Optional precomputed scale tensor.
    :param return_compressed_weight: If True, besides decompressed weight will also return compressed weight and scale.
    :return: Dequantized weight tensor or a tuple containing the decompressed weight, compressed weight and scale.


Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 18 out of 18 changed files in this pull request and generated 4 comments.



"""
The method dequantizes the given integer weights to float point data type in accordance with the scale and
zero_point data type.


Copilot AI Mar 4, 2026


The docstring for do_integer_dequantization (line 415) documents :param compressed_weight: but the actual function parameter is named compressed_weights (note the plural). This is a mismatch between the docstring and the function signature.

Comment on lines +167 to +168
:return: CompressedWeight instance containing the compressed weight tensor, scale,
and optionally second degree scale or codebook with indexes.

Copilot AI Mar 4, 2026


The docstring at line 167-168 says :return: Returns quantized (for codebook normalized) weight tensor and corresponding scale tensor, optional second degree scale and optional indexes for codebook. This return description describes individual tuple elements, but the function now returns a single CompressedWeight object. The docstring should be updated to accurately reflect the new return type.

Suggested change
:return: CompressedWeight instance containing the compressed weight tensor, scale,
and optionally second degree scale or codebook with indexes.
:return: CompressedWeight instance encapsulating the compressed weight tensor and associated scale data.

Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 18 out of 18 changed files in this pull request and generated 1 comment.




compressed_weight = do_integer_quantization(w, config, -1)

assert np.allclose(np.abs(compressed_weight.tensor.data), np.abs(w.data))
Contributor


I'd check that such 2-scale decompression subgraph can be inferred by OpenVINO on CPU. single layer test would be enough.

Collaborator Author


Made a small sanity test with a reference output, please check
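One way to build such a sanity check is to mirror the 2-scale decompression subgraph in plain numpy and compare the compiled model's output against a dense reference. The sketch below is only the numpy side of that comparison; the names, shapes, and values are invented for illustration and are not taken from the PR's test.

```python
import numpy as np

def two_scale_decompress(q, group_scale, second_degree_scale, group_size=16):
    """Numpy mirror of the subgraph: Convert(q) * Convert(group_scale) * global fp32 scale."""
    groups = q.reshape(q.shape[0], -1, group_size).astype(np.float32)
    deq = groups * group_scale[..., None] * second_degree_scale
    return deq.reshape(q.shape)

# invented shapes: 8 output channels, 32 inputs -> 2 groups of 16 per channel
rng = np.random.default_rng(0)
x = rng.standard_normal((4, 32)).astype(np.float32)
q = rng.choice([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], size=(8, 32))  # E2M1-representable values
gs = rng.uniform(0.01, 0.1, size=(8, 2)).astype(np.float32)             # per-group (f8e4m3) scales
g2 = np.float32(0.5)                                                    # fp32 second-degree scale

w_deq = two_scale_decompress(q, gs, g2)
y_ref = x @ w_deq.T  # a compiled model with the same weights should reproduce this on CPU
```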

:param zero_point: The zero-point, it is the value of the compression type corresponding to the value 0
in the non-compression realm. Applicable for INT quantization.
:param codebook: The codebook (LUT) for the weight compression. Applicable for vector quantization
:param second_degree_scale: The second degree scale used when the decompression scale itself is compressed.
Collaborator


Is this the official name for this kind of scale? I've seen the terms "super-scale" or "super-block-scale" before.

Collaborator Author


With Claude I found:

NVIDIA Model Optimizer (NVIDIA/Model-Optimizer) — Implementation of NVFP4 quantization showing global_amax, global_scale, weights_scaling_factor_2, and _double_scale:

https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt/torch/quantization/qtensor/nvfp4_tensor.py
https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt/torch/quantization/triton/fp4_kernel_hopper.py
https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt/torch/quantization/nn/modules/tensor_quantizer.py (class NVFP4StaticQuantizer)

I like the global scale name, what do you think?

Collaborator


Yes, global scale looks good to me.

| MXFP8_E4M3 | E4M3 | E8M0 | Group-wise (32) | [MX-compliant FP8](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) |
| FP8_E4M3 | E4M3 | FP16 | Per-channel / Group-wise | [FP8](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) |
| FP4 | E2M1 | FP16 | Per-channel / Group-wise | [FP4](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) |
| NVFP4 | E2M1 | E4M3 per group / FP32 per weight | Group-wise (16) | [NVFP4](https://www.arxiv.org/pdf/2602.14582) |
Collaborator


This probably links to the wrong paper:
"YOLO26: A Comprehensive Architecture Overview and Key Improvements"

Collaborator Author


Ohh, nice catch!


Labels

API Public API-impacting changes Code Freeze documentation Improvements or additions to documentation NNCF ONNX Pull requests that updates NNCF ONNX NNCF OpenVINO Pull requests that updates NNCF OpenVINO NNCF PT Pull requests that updates NNCF PyTorch
