
Commit 41637d0

Introduce support for NF4 data type for OV weight compression (#988)
* Add NF4 weight format
* remove test
* Update optimum/intel/openvino/configuration.py
  Co-authored-by: Nikita Savelyev <[email protected]>
* Update optimum/intel/openvino/configuration.py
  Co-authored-by: Nikita Savelyev <[email protected]>
* Add extra checks
* apply black
---------
Co-authored-by: Nikita Savelyev <[email protected]>
1 parent 5c879b9 commit 41637d0

6 files changed: +25, -10 lines

optimum/commands/export/openvino.py

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ def parse_args_openvino(parser: "ArgumentParser"):
     optional_group.add_argument(
         "--weight-format",
         type=str,
-        choices=["fp32", "fp16", "int8", "int4", "mxfp4"],
+        choices=["fp32", "fp16", "int8", "int4", "mxfp4", "nf4"],
         default=None,
         help="The weight format of the exported model.",
     )
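The new choice is consumed by the same export entry point the CLI tests exercise below; a minimal sketch of an invocation, assuming an installed optimum-intel with OpenVINO extras (the model id and output directory are illustrative):

import subprocess

# Sketch only: export a causal-LM checkpoint with NF4-compressed weights.
# "gpt2" and "ov_gpt2_nf4" are illustrative placeholders.
subprocess.run(
    "optimum-cli export openvino --model gpt2 --task text-generation-with-past "
    "--weight-format nf4 ov_gpt2_nf4",
    shell=True,
    check=True,
)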

optimum/intel/openvino/configuration.py

Lines changed: 10 additions & 8 deletions
@@ -347,7 +347,7 @@ class OVWeightQuantizationConfig(OVQuantizationConfigBase):
             Indicates whether to apply a scale estimation algorithm that minimizes the L2 error between the original and
             compressed layers. Providing a dataset is required to run scale estimation.
         weight_format (`str`, defaults to 'int'):
-            Data format weights are compressed to. Possible values: ['int4', 'int8', 'mxfp4'].
+            Data format weights are compressed to. Possible values: ['int4', 'int8', 'mxfp4', 'nf4'].
         qptq (`bool`, *optional*):
             Whether to apply GPTQ algorithm. GPTQ optimizes compressed weights in a layer-wise fashion to minimize the
             difference between activations of a compressed and original layer. Dataset is required to run GPTQ.
@@ -455,20 +455,22 @@ def post_init(self):
 
         if self.weight_format is None:
             self.weight_format = "int4" if self.bits == 4 else "int8"
-        if self.weight_format not in ["int4", "int8", "mxfp4"]:
+        if self.weight_format not in ["int4", "int8", "mxfp4", "nf4"]:
             raise ValueError(
-                f"Weight format must be one of the following: ['int4', 'int8', 'mxfp4'], but found: {self.weight_format}."
+                f"Weight format must be one of the following: ['int4', 'int8', 'mxfp4', 'nf4'], but found: {self.weight_format}."
             )
-        if self.weight_format == "mxfp4":
+        if self.weight_format in ["mxfp4", "nf4"]:
             if self.bits != 4:
                 raise ValueError(
-                    f"When applying weight compression with 'mxfp4' weight format the `bits` parameters must be set to 4, but found {self.bits}"
+                    f"When applying weight compression with '{self.weight_format}' weight format, the `bits` parameter must be set to 4, but found {self.bits}"
                 )
             if self.quant_method == OVQuantizationMethod.AWQ:
-                raise ValueError("The AWQ algorithm is not supported for 'mxfp4' weight format")
+                raise ValueError(f"The AWQ algorithm is not supported for '{self.weight_format}' weight format")
             if self.scale_estimation:
-                raise ValueError("The Scale Estimation algorithm is not supported for 'mxfp4' weight format")
-            if self.gptq:
+                raise ValueError(
+                    f"The Scale Estimation algorithm is not supported for '{self.weight_format}' weight format"
+                )
+            if self.weight_format == "mxfp4" and self.gptq:
                 raise ValueError("The GPTQ algorithm is not supported for 'mxfp4' weight format")
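Based on the validation above, a minimal sketch of driving the same path through the Python API (model id and group size are illustrative, not part of this commit):

from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig

# NF4 is a 4-bit format: post_init above requires bits=4 and rejects AWQ and
# scale estimation for it, while GPTQ remains blocked only for mxfp4.
quantization_config = OVWeightQuantizationConfig(bits=4, weight_format="nf4", group_size=32)

# Illustrative checkpoint; any causal LM exportable to OpenVINO would do.
model = OVModelForCausalLM.from_pretrained("gpt2", export=True, quantization_config=quantization_config)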

optimum/intel/openvino/quantization.py

Lines changed: 2 additions & 0 deletions
@@ -930,6 +930,8 @@ def _weight_only_quantization(
 
     if config.weight_format == "mxfp4":
         mode = CompressWeightsMode.E2M1
+    elif config.weight_format == "nf4":
+        mode = CompressWeightsMode.NF4
     else:
         if config.bits == 8:
             mode = CompressWeightsMode.INT8_SYM if config.sym else CompressWeightsMode.INT8_ASYM
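For orientation, a rough sketch of the NNCF call this mode feeds into, assuming an already-exported OpenVINO IR; the model path and group size are illustrative:

import nncf
import openvino as ov
from nncf import CompressWeightsMode

# Sketch: weight_format="nf4" selects CompressWeightsMode.NF4 before NNCF's
# weight-compression pass. Model path and group_size are placeholders.
ov_model = ov.Core().read_model("openvino_model.xml")
compressed = nncf.compress_weights(ov_model, mode=CompressWeightsMode.NF4, group_size=32)
ov.save_model(compressed, "openvino_model_nf4.xml", compress_to_fp16=False)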

tests/openvino/test_exporters_cli.py

Lines changed: 2 additions & 1 deletion
@@ -108,6 +108,7 @@ class OVCLIExportTestCase(unittest.TestCase):
         ("text-generation-with-past", "opt125m", "int4 --sym --group-size 128", {"int8": 4, "int4": 72}),
         ("text-generation-with-past", "opt125m", "int4 --group-size 64", {"int8": 4, "int4": 144}),
         ("text-generation-with-past", "opt125m", "mxfp4", {"int8": 4, "f4e2m1": 72, "f8e8m0": 72}),
+        ("text-generation-with-past", "opt125m", "nf4", {"int8": 4, "nf4": 72}),
         ("text-generation-with-past", "llama_awq", "int4 --ratio 1.0 --sym --group-size 8 --all-layers", {"int4": 16}),
         (
             "text-generation-with-past",
@@ -267,7 +268,7 @@ def test_exporters_cli_hybrid_quantization(self, model_type: str, exp_num_fq: in
         self.assertEqual(exp_num_fq, num_fq)
 
     @parameterized.expand(TEST_4BIT_CONFIGURATIONS)
-    def test_exporters_cli_int4(self, task: str, model_type: str, option: str, expected_num_weight_nodes: dict):
+    def test_exporters_cli_4bit(self, task: str, model_type: str, option: str, expected_num_weight_nodes: dict):
         with TemporaryDirectory() as tmpdir:
             result = subprocess.run(
                 f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --task {task} --weight-format {option} {tmpdir}",

tests/openvino/test_quantization.py

Lines changed: 7 additions & 0 deletions
@@ -206,6 +206,13 @@ class OVWeightCompressionTest(unittest.TestCase):
             dict(bits=4, weight_format="mxfp4", group_size=32),
             {"f4e2m1": 20, "f8e8m0": 20, "int8": 4},
         ),
+        (
+            OVModelForCausalLM,
+            "gpt2",
+            False,
+            dict(bits=4, weight_format="nf4", group_size=32),
+            {"nf4": 20, "int8": 4},
+        ),
         (
             OVModelForCausalLM,
             "gpt2",

tests/openvino/utils_tests.py

Lines changed: 3 additions & 0 deletions
@@ -195,6 +195,7 @@ def get_num_quantized_nodes(model):
         "int4": 0,
         "f4e2m1": 0,
         "f8e8m0": 0,
+        "nf4": 0,
     }
     ov_model = model if isinstance(model, ov.Model) else model.model
     for elem in ov_model.get_ops():
@@ -210,4 +211,6 @@ def get_num_quantized_nodes(model):
             num_weight_nodes["f4e2m1"] += 1
         if type_name == "f8e8m0":
             num_weight_nodes["f8e8m0"] += 1
+        if type_name == "nf4":
+            num_weight_nodes["nf4"] += 1
     return num_fake_quantize, num_weight_nodes
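A standalone sketch of the same counting logic for NF4 constants, useful for spot-checking an exported model outside the test suite (the model path is illustrative):

import openvino as ov

# Count weight constants stored as NF4, mirroring get_num_quantized_nodes above.
ov_model = ov.Core().read_model("openvino_model_nf4.xml")
nf4_constants = sum(
    1
    for op in ov_model.get_ops()
    if op.get_type_name() == "Constant" and op.get_output_element_type(0).get_type_name() == "nf4"
)
print(f"nf4 weight constants: {nf4_constants}")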
