Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,9 @@ nncf-tests.xml
compressed_graph.dot
original_graph.dot
tests/post_training/**/*memory_logs

output_*
*eval*
debug*
*.pth
*.db
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Directory containing the exported OpenVINO IR model (produced by the export script).
IR_DIR=u2_u4_ov_model
# Evaluate the OpenVINO model on CPU with lm-evaluation-harness.
# --limit 100 restricts evaluation to the first 100 samples for a quick check.
lm_eval \
    --model openvino \
    --model_args pretrained=$IR_DIR \
    --device cpu \
    --output_path ov_eval \
    --limit 100 \
    --tasks lambada_openai
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright (c) 2026 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path

import torch
from optimum.exporters.openvino.convert import export_from_model
from torch import nn
from transformers import AutoModelForCausalLM

import nncf
from nncf.parameters import StripFormat
from nncf.torch.function_hook.wrapper import get_hook_storage
from nncf.torch.model_creation import load_from_config
from nncf.torch.quantization.layers import SymmetricLoraQuantizer # noqa: F401


def load_checkpoint(model: nn.Module, ckpt_file: Path) -> nn.Module:
    """
    Restores a tuned model's NNCF state from a checkpoint file.

    Re-inserts the Fake Quantizers (with absorbable LoRA adapters) according to the saved
    NNCF config and loads their parameters, plus the model weights when present.

    :param model: The model to restore the checkpoint into.
    :param ckpt_file: Path to the checkpoint file.
    :returns: The model with the NNCF state loaded from the checkpoint.
    """
    checkpoint = torch.load(ckpt_file, weights_only=False, map_location="cpu")
    # Rebuild FQ/LoRA hook placement first, then populate parameter values.
    restored = load_from_config(model, checkpoint["nncf_config"])
    if "model_state" in checkpoint:
        restored.load_state_dict(checkpoint["model_state"])
    get_hook_storage(restored).load_state_dict(checkpoint["nncf_state_dict"])
    return restored


# HF model id of the base (uncompressed) model the checkpoint was tuned from.
pretrained = "Qwen/Qwen3-4B"
# Checkpoint produced by tuning; expected to contain nncf_config / nncf_state_dict
# (and optionally model_state) — see load_checkpoint above.
ckpt_file = "nncf_checkpoint_epoch10.pth"
# Output directory for the exported OpenVINO IR.
ir_dir = "u2_u4_ov_model"
with torch.no_grad():
    # Load the FP32 base model on CPU, restore the NNCF quantizers from the checkpoint,
    # strip them into a dequantization (DQ) representation in place, and export to IR.
    model_to_eval = AutoModelForCausalLM.from_pretrained(pretrained, torch_dtype=torch.float32, device_map="cpu")
    model_to_eval = load_checkpoint(model_to_eval, ckpt_file)
    model_to_eval = nncf.strip(model_to_eval, do_copy=False, strip_format=StripFormat.DQ)
    export_from_model(model_to_eval, ir_dir, device="cpu")
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# Copyright (c) 2026 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Synthetic test to verify INT2 symmetric decompression subgraph
can be exported to OpenVINO IR via torch.jit.trace + openvino.convert_model.
"""

import numpy as np
import openvino as ov
import torch


def pack_uint2(tensor: torch.Tensor) -> torch.Tensor:
    """
    Pack groups of four 2-bit values into single bytes.

    The input is flattened into rows of four; element i of each row is masked to
    its low 2 bits and placed at bit offset 2*i of the output byte.

    :param tensor: Integer tensor whose element count is a multiple of 4.
    :return: 1-D tensor with one byte per four input values.
    """
    quads = tensor.contiguous().reshape(-1, 4)
    packed = torch.bitwise_and(quads[..., 0], 3)
    for slot in (1, 2, 3):
        packed = packed | (torch.bitwise_and(quads[..., slot], 3) << (2 * slot))
    return packed


def unpack_uint2(packed_tensor: torch.Tensor) -> torch.Tensor:
    """
    Unpack bytes into four 2-bit values each (inverse of pack_uint2).

    :param packed_tensor: Byte tensor; each byte holds four 2-bit fields,
        lowest bits first.
    :return: Tensor with an extra trailing dimension of size 4 holding the fields.
    """
    fields = [
        torch.bitwise_and(torch.bitwise_right_shift(packed_tensor, shift), 3)
        for shift in (0, 2, 4, 6)
    ]
    return torch.stack(fields, dim=-1)


def decompress_symmetric(input, scale):
    """Cast *input* to the scale's dtype and multiply elementwise by *scale*."""
    promoted = input.type(dtype=scale.dtype)
    return promoted * scale


class INT2SymmetricLinear(torch.nn.Module):
    """
    Linear layer storing its weight as packed INT2 symmetric values.

    Mirrors the NNCF INT2SymmetricWeightsDecompressor pattern: signed values
    [-2, 1] are stored as uint2 codes [0, 3] with a fixed zero point of 2,
    four codes per byte, plus a per-group float16 scale.
    """

    ZERO_POINT_VALUE = 2

    def __init__(self, in_features, out_features, group_size):
        super().__init__()
        assert out_features % group_size == 0
        num_groups = out_features // group_size

        weight_shape = (num_groups, group_size, in_features)
        per_group_scale_shape = (num_groups, 1, in_features)

        # Deterministic synthetic weights: uint2 codes in [0, 3] and scales in [-1, 1).
        generator = np.random.default_rng(seed=42)
        uint2_codes = generator.integers(0, 4, size=weight_shape, dtype=np.uint8)
        raw_scale = (generator.random(per_group_scale_shape, dtype=np.float32) * 2.0 - 1.0).astype(np.float32)

        self.compressed_weight_shape = weight_shape
        self.packed_weight = torch.nn.Parameter(pack_uint2(torch.from_numpy(uint2_codes)), requires_grad=False)
        self.register_buffer("_scale", torch.from_numpy(raw_scale).to(torch.float16))
        self.register_buffer("_zero_point", torch.tensor(self.ZERO_POINT_VALUE, dtype=torch.uint8))
        self.result_shape = (out_features, in_features)
        self.result_dtype = torch.float32

    def forward(self, x):
        # NNCF INT2 symmetric decompression pattern: unpack -> center -> scale -> matmul.
        codes = unpack_uint2(self.packed_weight).reshape(self.compressed_weight_shape)
        centered = codes.type(dtype=self.result_dtype) - self._zero_point.type(dtype=self.result_dtype)
        weight = decompress_symmetric(centered, self._scale)
        weight = weight.reshape(self.result_shape).type(dtype=self.result_dtype)
        return torch.matmul(x, weight.t())


class SmallModel(torch.nn.Module):
    """Two INT2-compressed linear layers (16->32->16) with a ReLU in between."""

    def __init__(self):
        super().__init__()
        self.linear1 = INT2SymmetricLinear(16, 32, group_size=4)
        self.linear2 = INT2SymmetricLinear(32, 16, group_size=4)

    def forward(self, x):
        hidden = torch.relu(self.linear1(x))
        return self.linear2(hidden)


def main():
    """
    End-to-end synthetic check that the INT2 symmetric decompression subgraph
    survives conversion to OpenVINO IR:

    1. Convert the torch model to an OV model.
    2. Verify u2-typed Constants appear (one per compressed linear layer).
    3. Save the IR to disk.
    4. Compare OV inference against the torch reference output.
    """
    import os  # local import: only needed to create the IR output directory

    print("=== Synthetic INT2 export test ===")
    model = SmallModel()
    model.eval()

    dummy_input = torch.randn(1, 16)

    # Step 1: Convert to OpenVINO IR
    print("[1/4] Converting to OpenVINO IR...")
    ov_model = ov.convert_model(model, example_input=dummy_input)
    print("    Conversion successful.")

    # Step 2: Check u2 constants in the converted OV model
    print("[2/4] Checking u2 constants in OV model...")
    u2_constants = []
    for op in ov_model.get_ordered_ops():
        if op.get_type_name() == "Constant" and "uint2" in str(op.get_output_element_type(0)):
            u2_constants.append(op)

    expected_u2_count = 2  # one per INT2SymmetricLinear layer
    print(f"    Found {len(u2_constants)} u2 constant(s) (expected {expected_u2_count}).")
    for c in u2_constants:
        print(f"      - {c.get_friendly_name()}: shape={c.get_output_partial_shape(0)}")
    assert len(u2_constants) == expected_u2_count, f"Expected {expected_u2_count} u2 constants, got {len(u2_constants)}"
    print("    PASSED - u2 constants detected.")

    # Step 3: Save IR
    ir_path = "/tmp/test_int2_synthetic_ir"
    print(f"[3/4] Saving IR to {ir_path}...")
    # BUGFIX: ov.save_model does not create the target directory; ensure it exists
    # so the save does not fail on a clean machine.
    os.makedirs(ir_path, exist_ok=True)
    ov.save_model(ov_model, f"{ir_path}/model.xml")
    print("    Save successful.")

    # Step 4: Verify inference
    print("[4/4] Running inference comparison...")
    with torch.no_grad():
        torch_out = model(dummy_input).numpy()

    compiled = ov.Core().compile_model(ov_model, "CPU")
    ov_out = compiled(dummy_input.numpy())[0]

    max_diff = np.max(np.abs(torch_out - ov_out))
    print(f"    Max absolute difference: {max_diff:.6e}")
    if max_diff < 1e-2:
        print("    PASSED - Outputs match within tolerance.")
    else:
        print(f"    WARNING - Large difference detected: {max_diff}")

    print("\n=== All steps completed successfully! ===")


# Run the synthetic export test only when executed as a script.
if __name__ == "__main__":
    main()
3 changes: 3 additions & 0 deletions src/nncf/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ class CompressWeightsMode(StrEnum):
:param INT4_ASYM: The same as INT4_SYM mode, but weights are quantized to a primary precision asymmetrically
with a typical non-fixed zero point.
https://github.com/openvinotoolkit/nncf/blob/develop/docs/usage/training_time_compression/other_algorithms/LegacyQuantization.md#asymmetric-quantization
:param INT2_SYM: Stands for 2-bit integer symmetric quantization without zero point.
Similar to INT4_SYM but with a 2-bit primary precision.
:param NF4: The same as INT4_SYM mode, but primary precision is NF4 data type without zero point.
:param INT8: Mode is deprecated and will be removed in future releases. Please use `INT8_ASYM` instead.
:param MXFP4: MX-compliant FP4 format with E2M1 values sharing group-level E8M0 scale. The size of group is 32.
Expand All @@ -103,6 +105,7 @@ class CompressWeightsMode(StrEnum):
INT8_ASYM = "int8_asym"
INT4_SYM = "int4_sym"
INT4_ASYM = "int4_asym"
INT2_SYM = "int2_sym"
NF4 = "nf4"
CB4 = "cb4"
INT8 = "int8" # Deprecated mode
Expand Down
27 changes: 18 additions & 9 deletions src/nncf/quantization/algorithms/weight_compression/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,26 @@ def num_bits(self):
"""
:return: number of bits that is used for storing a single quantized value in the given mode.
"""
if self.mode in [
CompressWeightsMode.INT8_SYM,
CompressWeightsMode.INT8_ASYM,
CompressWeightsMode.FP8_E4M3,
CompressWeightsMode.MXFP8_E4M3,
]:
return 8
return 4
return {
CompressWeightsMode.INT8_SYM: 8,
CompressWeightsMode.INT8_ASYM: 8,
CompressWeightsMode.FP8_E4M3: 8,
CompressWeightsMode.MXFP8_E4M3: 8,
CompressWeightsMode.INT4_SYM: 4,
CompressWeightsMode.INT4_ASYM: 4,
CompressWeightsMode.NF4: 4,
CompressWeightsMode.MXFP4: 4,
CompressWeightsMode.FP4: 4,
CompressWeightsMode.CB4: 4,
CompressWeightsMode.INT2_SYM: 2,
}.get(self.mode, 4)

@property
def is_asym_mode(self):
return self.mode in [CompressWeightsMode.INT4_ASYM, CompressWeightsMode.INT8_ASYM]
return self.mode in [
CompressWeightsMode.INT4_ASYM,
CompressWeightsMode.INT8_ASYM,
]

@property
def is_integer(self):
Expand Down Expand Up @@ -101,6 +109,7 @@ def compression_dtype(self) -> TensorDataType:
dtype_per_mode = {
CompressWeightsMode.INT4_SYM: TensorDataType.int4,
CompressWeightsMode.INT4_ASYM: TensorDataType.uint4,
CompressWeightsMode.INT2_SYM: TensorDataType.int2,
CompressWeightsMode.INT8_ASYM: TensorDataType.uint8,
CompressWeightsMode.INT8_SYM: TensorDataType.int8,
CompressWeightsMode.NF4: TensorDataType.nf4,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@
CompressWeightsMode.INT8_SYM,
CompressWeightsMode.INT4_ASYM,
CompressWeightsMode.INT4_SYM,
CompressWeightsMode.INT2_SYM,
)

OPTIMIZED_COMPRESSION_COMPATIBLE_FLOAT_MODES = (
Expand Down
2 changes: 2 additions & 0 deletions src/nncf/tensor/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class TensorDataType(StrEnum):
uint8 = auto()
uint4 = auto()
int4 = auto()
int2 = auto()

def is_float(self) -> bool:
"""
Expand All @@ -78,6 +79,7 @@ def itemsize(self) -> int:
TensorDataType.nf4: 4,
TensorDataType.uint4: 4,
TensorDataType.int4: 4,
TensorDataType.int2: 2,
TensorDataType.f8e4m3: 8,
TensorDataType.f8e5m2: 8,
TensorDataType.int8: 8,
Expand Down
2 changes: 1 addition & 1 deletion src/nncf/torch/function_hook/strip.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def replace_quantizer_to_compressed_weight_with_decompressor(model: TModel) -> T
msg = ""
if hook_module._qspec.half_range or hook_module._qspec.narrow_range:
msg += "Unexpected parameters of quantizers on strip: half_range and narrow_range should be False.\n"
if hook_module.num_bits not in [4, 8]:
if hook_module.num_bits not in [2, 4, 8]:
msg += f"Unsupported number of bits {hook_module.num_bits} for the quantizer {hook_module}.\n"
if msg:
raise nncf.ValidationError(msg)
Expand Down
58 changes: 58 additions & 0 deletions src/nncf/torch/quantization/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,12 @@
from nncf.torch.quantization.quantize_functions import decompress_symmetric
from nncf.torch.quantization.quantize_functions import get_scale_zp_from_input_low_input_high
from nncf.torch.quantization.quantize_functions import pack_int4
from nncf.torch.quantization.quantize_functions import pack_uint2
from nncf.torch.quantization.quantize_functions import pack_uint4
from nncf.torch.quantization.quantize_functions import symmetric_quantize
from nncf.torch.quantization.quantize_functions import symmetric_quantize_lora
from nncf.torch.quantization.quantize_functions import unpack_int4
from nncf.torch.quantization.quantize_functions import unpack_uint2
from nncf.torch.quantization.quantize_functions import unpack_uint4
from nncf.torch.return_types import maybe_get_values_from_torch_return_type
from nncf.torch.return_types import maybe_wrap_to_torch_return_type
Expand Down Expand Up @@ -1467,6 +1469,62 @@ def forward(self, x):
return result


class INT2SymmetricWeightsDecompressor(BaseWeightsDecompressor):
"""
Applies symmetric decompression of 2-bit compressed weights in the forward pass.

Weights with values in [-2, -1, 0, 1] are stored as uint2 [0, 1, 2, 3] using
a hardcoded zero point of 2. Four uint2 values are packed into each uint8 byte.
"""

ZERO_POINT_VALUE = 2

def __init__(
self,
scale: torch.Tensor,
compressed_weight_shape: tuple[int, ...],
result_shape: tuple[int, ...] | None = None,
result_dtype: torch.dtype | None = None,
):
"""
:param scale: A scale in quantization scheme
:param compressed_weight_shape: A compressed weight shape
:param result_shape: (Optional) A shape that result should be reshaped to
:param result_dtype: (Optional) A data type that result should be cast to
"""
super().__init__()
self.register_buffer("_scale", scale.type(dtype=torch.float16))
self.register_buffer(
"_zero_point",
torch.tensor(self.ZERO_POINT_VALUE, dtype=torch.uint8),
)

self.compressed_weight_shape = compressed_weight_shape
self.result_shape = result_shape
self.result_dtype = result_dtype

@property
def quantization_mode(self) -> QuantizationMode:
return QuantizationMode.SYMMETRIC

def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
if torch.any((weight < 0) | (weight > 3)):
msg = "Weight values are not in [0, 3]."
raise ValueError(msg)
return pack_uint2(weight.type(dtype=torch.uint8))

def forward(self, x):
x = unpack_uint2(x)
x = x.reshape(self.compressed_weight_shape)

x = x.type(dtype=self.result_dtype) - self._zero_point.type(dtype=self.result_dtype)

result = decompress_symmetric(x, self._scale)
result = result.reshape(self.result_shape) if self.result_shape is not None else result
result = result.type(dtype=self.result_dtype) if self.result_dtype is not None else result
return result


@COMPRESSION_MODULES.register()
class SQMultiply(torch.nn.Module, StatefulModuleInterface):
SCALE_SHAPE_KEY = "scale_shape"
Expand Down
Loading
Loading