@@ -147,7 +147,6 @@ def __init__(
147147        self .key_value_input_names  =  [key  for  key  in  self .input_names  if  "key_values"  in  key ]
148148        self .key_value_output_names  =  [key  for  key  in  self .output_names  if  "present"  in  key ]
149149        # Keeping the original model for serialization 
150-         self ._original_model  =  self .model .clone () if  not  compile_only  else  None 
151150        self ._pkv_precision  =  Type .f32 
152151        self .next_beam_idx  =  None 
153152        self ._past_length  =  0 
@@ -197,6 +196,17 @@ def raise_error(model_prop, user_prop, name):
197196        if  not  self ._compile_only  and  enable_compilation :
198197            self .compile ()
199198
199+     @staticmethod  
200+     def  _get_model_with_updated_pkv_precision (model : openvino .Model , pkv_precision : Type ) ->  openvino .Model :
201+         ppp  =  PrePostProcessor (model )
202+         for  key  in  model .inputs :
203+             if  "past_key_values"  in  key .get_any_name () and  pkv_precision  !=  key .get_element_type ():
204+                 ppp .input (key .get_any_name ()).tensor ().set_element_type (pkv_precision )
205+         for  key  in  model .outputs :
206+             if  "present"  in  key .get_any_name () and  pkv_precision  !=  key .get_element_type ():
207+                 ppp .output (key .get_any_name ()).tensor ().set_element_type (pkv_precision )
208+         return  ppp .build ()
209+ 
200210    def  update_pkv_precision (self , force_fp32 = False ):
201211        if  not  self .use_cache  or  self .stateful  or  self ._compile_only :
202212            return 
@@ -216,20 +226,13 @@ def update_pkv_precision(self, force_fp32=False):
216226                if  inference_precision_hint  in  STR_TO_OV_TYPE :
217227                    pkv_precision  =  STR_TO_OV_TYPE [inference_precision_hint ]
218228
219-             ppp  =  PrePostProcessor (self .model )
220-             for  key  in  self .model .inputs :
221-                 if  "past_key_values"  in  key .get_any_name () and  pkv_precision  !=  key .get_element_type ():
222-                     ppp .input (key .get_any_name ()).tensor ().set_element_type (pkv_precision )
223-             for  key  in  self .model .outputs :
224-                 if  "present"  in  key .get_any_name () and  pkv_precision  !=  key .get_element_type ():
225-                     ppp .output (key .get_any_name ()).tensor ().set_element_type (pkv_precision )
226- 
227-             self .model  =  ppp .build ()
229+             self .model  =  self ._get_model_with_updated_pkv_precision (self .model , pkv_precision )
228230            self ._pkv_precision  =  pkv_precision 
231+             self .request  =  None 
229232        else :
230233            if  hasattr (self , "_pkv_precision" ) and  self ._pkv_precision  !=  Type .f32 :
234+                 self .model  =  self ._get_model_with_updated_pkv_precision (self .model , Type .f32 )
231235                self ._pkv_precision  =  Type .f32 
232-                 self .model  =  self ._original_model .clone ()
233236                if  self .is_dynamic :
234237                    self .model  =  self ._reshape (self .model , - 1 , - 1 )
235238                self .request  =  None 
@@ -248,7 +251,11 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
248251            raise  ValueError (
249252                "`save_pretrained()` is not supported with `compile_only` mode, please intialize model without this option" 
250253            )
251-         model_to_save  =  self .model  if  self ._pkv_precision  ==  Type .f32  else  self ._original_model 
254+         model_to_save  =  (
255+             self .model 
256+             if  self ._pkv_precision  ==  Type .f32 
257+             else  self ._get_model_with_updated_pkv_precision (self .model .clone (), Type .f32 )
258+         )
252259        dst_path  =  os .path .join (save_directory , OV_XML_FILE_NAME )
253260        openvino .save_model (model_to_save , dst_path , compress_to_fp16 = False )
254261
0 commit comments