Skip to content

Commit 4f3ec52

Browse files
Update pkv precision at save_pretrained call (#1235)
* Update pkv precision at save_pretrained call
* Trigger tests
1 parent 68f6752 commit 4f3ec52

File tree

1 file changed

+19
-12
lines changed

1 file changed

+19
-12
lines changed

optimum/intel/openvino/modeling_decoder.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,6 @@ def __init__(
147147
self.key_value_input_names = [key for key in self.input_names if "key_values" in key]
148148
self.key_value_output_names = [key for key in self.output_names if "present" in key]
149149
# Keeping the original model for serialization
150-
self._original_model = self.model.clone() if not compile_only else None
151150
self._pkv_precision = Type.f32
152151
self.next_beam_idx = None
153152
self._past_length = 0
@@ -197,6 +196,17 @@ def raise_error(model_prop, user_prop, name):
197196
if not self._compile_only and enable_compilation:
198197
self.compile()
199198

199+
@staticmethod
def _get_model_with_updated_pkv_precision(model: openvino.Model, pkv_precision: Type) -> openvino.Model:
    """Rebuild `model` so that its key/value cache tensors use `pkv_precision`.

    Every model input whose name contains "past_key_values" and every model
    output whose name contains "present" gets its tensor element type set to
    `pkv_precision`, unless it already has that element type.

    Args:
        model: The OpenVINO model whose cache inputs/outputs are adjusted.
        pkv_precision: The element type requested for the cache tensors.

    Returns:
        The model produced by `PrePostProcessor.build()`.
    """
    # NOTE(review): build() presumably applies the changes to `model` itself;
    # callers that must keep their model untouched should pass a clone — confirm.
    processor = PrePostProcessor(model)

    # Cache inputs: only touch ports whose current type differs from the target.
    for port in model.inputs:
        port_name = port.get_any_name()
        if "past_key_values" in port_name and pkv_precision != port.get_element_type():
            processor.input(port_name).tensor().set_element_type(pkv_precision)

    # Cache outputs follow the same rule, matched by the "present" marker.
    for port in model.outputs:
        port_name = port.get_any_name()
        if "present" in port_name and pkv_precision != port.get_element_type():
            processor.output(port_name).tensor().set_element_type(pkv_precision)

    return processor.build()
209+
200210
def update_pkv_precision(self, force_fp32=False):
201211
if not self.use_cache or self.stateful or self._compile_only:
202212
return
@@ -216,20 +226,13 @@ def update_pkv_precision(self, force_fp32=False):
216226
if inference_precision_hint in STR_TO_OV_TYPE:
217227
pkv_precision = STR_TO_OV_TYPE[inference_precision_hint]
218228

219-
ppp = PrePostProcessor(self.model)
220-
for key in self.model.inputs:
221-
if "past_key_values" in key.get_any_name() and pkv_precision != key.get_element_type():
222-
ppp.input(key.get_any_name()).tensor().set_element_type(pkv_precision)
223-
for key in self.model.outputs:
224-
if "present" in key.get_any_name() and pkv_precision != key.get_element_type():
225-
ppp.output(key.get_any_name()).tensor().set_element_type(pkv_precision)
226-
227-
self.model = ppp.build()
229+
self.model = self._get_model_with_updated_pkv_precision(self.model, pkv_precision)
228230
self._pkv_precision = pkv_precision
231+
self.request = None
229232
else:
230233
if hasattr(self, "_pkv_precision") and self._pkv_precision != Type.f32:
234+
self.model = self._get_model_with_updated_pkv_precision(self.model, Type.f32)
231235
self._pkv_precision = Type.f32
232-
self.model = self._original_model.clone()
233236
if self.is_dynamic:
234237
self.model = self._reshape(self.model, -1, -1)
235238
self.request = None
@@ -248,7 +251,11 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
248251
raise ValueError(
249252
"`save_pretrained()` is not supported with `compile_only` mode, please intialize model without this option"
250253
)
251-
model_to_save = self.model if self._pkv_precision == Type.f32 else self._original_model
254+
model_to_save = (
255+
self.model
256+
if self._pkv_precision == Type.f32
257+
else self._get_model_with_updated_pkv_precision(self.model.clone(), Type.f32)
258+
)
252259
dst_path = os.path.join(save_directory, OV_XML_FILE_NAME)
253260
openvino.save_model(model_to_save, dst_path, compress_to_fp16=False)
254261

0 commit comments

Comments
 (0)