
Commit e22b2fd

Fix NNCF quantization for decoder models (#727)
* Fix NNCF quantization for decoder models
* add test
* update op quant op
* remove deprecated warning
* update expected quantized
* enable stateful
* style
1 parent 1319d7b commit e22b2fd

File tree

3 files changed (+17 additions, -15 deletions)


optimum/intel/openvino/modeling_decoder.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -42,7 +42,7 @@
 from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS
 from .configuration import _DEFAULT_4BIT_CONFIGS, OVConfig, OVWeightQuantizationConfig, _check_default_4bit_configs
 from .modeling import _TOKENIZER_FOR_DOC, INPUTS_DOCSTRING, MODEL_START_DOCSTRING, OVModel
-from .utils import ONNX_WEIGHTS_NAME, OV_XML_FILE_NAME, STR_TO_OV_TYPE
+from .utils import ONNX_WEIGHTS_NAME, OV_TO_NP_TYPE, OV_XML_FILE_NAME, STR_TO_OV_TYPE


 if TYPE_CHECKING:
@@ -409,6 +409,7 @@ def prepare_inputs(
         elif self.use_cache:
             for input_name in self.key_value_input_names:
                 model_inputs = self.model.input(input_name)
+                dtype = OV_TO_NP_TYPE[model_inputs.get_element_type().get_type_name()]
                 shape = model_inputs.get_partial_shape()
                 if self.config.model_type == "chatglm":
                     shape[0] = 0
@@ -419,7 +420,7 @@
                     shape[2] = 0
                 else:
                     shape[1] = 0
-                inputs[input_name] = Tensor(model_inputs.get_element_type(), shape.get_shape())
+                inputs[input_name] = np.empty([dim.get_length() for dim in shape], dtype=dtype)
         else:
             # past_key_values are not used explicitly, instead they are handled inside the model
             if past_key_values is None:
```
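The change above builds the empty past-key-values placeholders as NumPy arrays whose dtype matches each input's OpenVINO element type, instead of constructing `openvino.runtime.Tensor` objects; with the sequence axis zeroed out, the arrays hold no data but keep the correct shape and type. A minimal sketch of the idea — the abbreviated `OV_TO_NP_TYPE` mapping here is illustrative, not the full table from `optimum/intel/openvino/utils.py`:

```python
import numpy as np

# Illustrative subset of the OV element-type -> NumPy dtype mapping.
OV_TO_NP_TYPE = {"f32": np.float32, "f16": np.float16, "i64": np.int64}

def empty_past_kv(type_name: str, shape: list) -> np.ndarray:
    # `shape` has the sequence axis already zeroed, e.g. [2, 12, 0, 64]
    # for a GPT-2-style key/value cache, so the array is empty but typed.
    return np.empty(shape, dtype=OV_TO_NP_TYPE[type_name])

placeholder = empty_past_kv("f32", [2, 12, 0, 64])
print(placeholder.shape, placeholder.dtype)  # (2, 12, 0, 64) float32
```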

optimum/intel/openvino/quantization.py

Lines changed: 4 additions & 5 deletions
```diff
@@ -347,7 +347,6 @@ def _quantize_ovbasemodel(
             remove_unused_columns=remove_unused_columns,
             data_collator=data_collator,
         )
-
         if self.model.export_feature == "text-generation" and self.model.use_cache:
             calibration_dataset = self._prepare_text_generation_dataset(
                 quantization_config, calibration_dataloader
@@ -430,6 +429,7 @@ _quantize_ovbasemodel(
             ),
             **kwargs,
         )
+
         self.model.model = quantized_model
         if save_directory is not None:
             self.model.save_pretrained(save_directory)
@@ -696,24 +696,23 @@ def _prepare_builtin_dataset(self, quantization_config: OVWeightQuantizationConfig
     def _prepare_text_generation_dataset(
         self, quantization_config: OVQuantizationConfig, calibration_dataloader: OVDataLoader
     ) -> nncf.Dataset:
-        # TODO: this function is not covered by tests, remove if not relevant anymore or cover by tests otherwise
-
         # Prefetch past_key_values
         self.model.update_pkv_precision(True)
         self.model.compile()
         collected_inputs = []

         num_samples = quantization_config.num_samples or 200

-        self.model.request = InferRequestWrapper(self.model.model.request, collected_inputs)
+        self.model.request = InferRequestWrapper(self.model.request, collected_inputs)
         try:
             for data in calibration_dataloader:
                 self.model.generate(**data, max_new_tokens=1)
                 if len(collected_inputs) >= num_samples:
                     break
         finally:
-            self.model.model.request = self.model.model.request.request
+            self.model.request = self.model.request.request
         calibration_dataset = nncf.Dataset(collected_inputs)
+
         return calibration_dataset

     def _prepare_unet_dataset(
```
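The substantive fix here is in `_prepare_text_generation_dataset`: the infer request is now wrapped and restored via `self.model.request` rather than `self.model.model.request`, which pointed one level too deep, so the inputs fed during `generate()` are actually intercepted and collected as calibration samples (and the stale TODO about missing test coverage is dropped, since the new test covers this path). A self-contained toy of the wrap/collect/unwrap pattern — `RecordingWrapper` and `FakeModel` are stand-ins for optimum-intel's `InferRequestWrapper` and the compiled decoder model, not the real classes:

```python
class RecordingWrapper:
    """Stand-in for InferRequestWrapper: records inputs, then delegates."""

    def __init__(self, request, collected_inputs):
        self.request = request            # keep a handle so we can unwrap later
        self.collected = collected_inputs

    def __call__(self, inputs):
        self.collected.append(inputs)     # record a calibration sample
        return self.request(inputs)

class FakeModel:
    """Stand-in for the decoder model exposing an infer request."""

    def __init__(self):
        self.request = lambda inputs: {"logits": None}

    def generate(self, **data):
        return self.request(data)

model, collected = FakeModel(), []
model.request = RecordingWrapper(model.request, collected)  # wrap
try:
    for batch in [{"input_ids": [1, 2]}, {"input_ids": [3]}]:
        model.generate(**batch)                              # inputs get recorded
finally:
    model.request = model.request.request                    # unwrap
print(len(collected))  # 2 samples, ready for nncf.Dataset(collected)
```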

tests/openvino/test_quantization.py

Lines changed: 10 additions & 8 deletions
```diff
@@ -73,12 +73,16 @@
 class OVQuantizerTest(unittest.TestCase):
-    SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS = (
+    SUPPORTED_ARCHITECTURES_TORCH_MODEL = (
         (OVModelForSequenceClassification, "bert", 32, 35),
-        # (OVModelForCausalLM, "gpt2", 41, 23),
+        (OVModelForCausalLM, "gpt2", 41, 3),
+    )
+    SUPPORTED_ARCHITECTURES_OV_MODEL = (
+        (OVModelForSequenceClassification, "bert", 32, 35),
+        (OVModelForCausalLM, "gpt2", 31, 22),
     )

-    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS)
+    @parameterized.expand(SUPPORTED_ARCHITECTURES_TORCH_MODEL)
     def test_automodel_static_quantization(self, model_cls, model_name, expected_fake_quantize, expected_int8):
         model_id = MODEL_NAMES[model_name]
         task = model_cls.export_feature
@@ -123,23 +127,21 @@ def preprocess_function(examples, tokenizer):
             loaded_config = OVConfig.from_pretrained(tmp_dir)
             self.assertEqual(ov_config.quantization_config.to_dict(), loaded_config.quantization_config.to_dict())

-    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS)
+    @parameterized.expand(SUPPORTED_ARCHITECTURES_OV_MODEL)
     def test_ovmodel_static_quantization(self, model_cls, model_name, expected_fake_quantize, expected_int8):
         model_id = MODEL_NAMES[model_name]
         task = model_cls.export_feature
         dataset_name, dataset_config_name, column_name = _TASK_TO_DATASET[task]
-        if "gpt2" in model_id:
-            expected_int8 -= 1

         def preprocess_function(examples, tokenizer):
             return tokenizer(examples[column_name], padding="max_length", max_length=128, truncation=True)

         with tempfile.TemporaryDirectory() as tmp_dir:
-            transformers_model = model_cls.from_pretrained(model_id, export=True)
+            ov_model = model_cls.from_pretrained(model_id, export=True)
             tokenizer = AutoTokenizer.from_pretrained(model_id)
             if tokenizer.pad_token is None:
                 tokenizer.pad_token = tokenizer.eos_token
-            quantizer = OVQuantizer.from_pretrained(transformers_model, task=task)
+            quantizer = OVQuantizer.from_pretrained(ov_model, task=task)

             calibration_dataset = quantizer.get_calibration_dataset(
                 dataset_name,
```
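The test now uses two parameterizations because quantizing from a PyTorch export and quantizing an already-converted OpenVINO model evidently yield different operation counts (41/3 vs. 31/22 for gpt2), and the gpt2 case is re-enabled rather than left commented out. For orientation, a sketch of the flow `test_ovmodel_static_quantization` exercises, using the `OVQuantizer` API as in this file; the model id, dataset, sample count, and save directory are illustrative:

```python
from functools import partial

from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM, OVQuantizer

model_id = "gpt2"  # illustrative; the test suite uses a tiny checkpoint
ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def preprocess_function(examples, tokenizer):
    return tokenizer(examples["text"], padding="max_length", max_length=128, truncation=True)

quantizer = OVQuantizer.from_pretrained(ov_model, task="text-generation")
calibration_dataset = quantizer.get_calibration_dataset(
    "wikitext",
    dataset_config_name="wikitext-2-raw-v1",
    preprocess_function=partial(preprocess_function, tokenizer=tokenizer),
    num_samples=10,
    dataset_split="train",
)
# Static quantization: inserts FakeQuantize ops and saves the int8 model.
quantizer.quantize(calibration_dataset=calibration_dataset, save_directory="gpt2-int8")
```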
