@@ -579,6 +579,38 @@ def test_find_files_matching_pattern_with_quantized_ov_model(self):
579579 ov_files = _find_files_matching_pattern (local_dir , pattern = pattern , subfolder = subfolder )
580580 self .assertTrue (len (ov_files ) == 1 )
581581
def test_load_from_hub_onnx_model_and_save(self):
    """Load an ONNX checkpoint from the Hub, save it, and reload the saved copy.

    The test checks that:
      * the model loaded with ``from_onnx=True`` exposes a ``PretrainedConfig``
        and reports ``PERFORMANCE_HINT == "LATENCY"`` both in ``ov_config``
        and on the compiled model;
      * ``save_pretrained`` writes the OpenVINO ``.xml`` / ``.bin`` pair;
      * a model reloaded from the saved directory (regular and
        ``compile_only=True``) produces logits identical to the first model.
    """
    model_id = "katuni4ka/tiny-random-LlamaForCausalLM-onnx"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokens = tokenizer("This is a sample input", return_tensors="pt")
    loaded_model = OVModelForCausalLM.from_pretrained(model_id, from_onnx=True)
    self.assertIsInstance(loaded_model.config, PretrainedConfig)
    # Test that PERFORMANCE_HINT is set to LATENCY by default
    self.assertEqual(loaded_model.ov_config.get("PERFORMANCE_HINT"), "LATENCY")
    self.assertEqual(loaded_model.request.get_compiled_model().get_property("PERFORMANCE_HINT"), "LATENCY")
    # Reference outputs that every reloaded variant is compared against.
    loaded_model_outputs = loaded_model(**tokens)

    # NOTE(review): block nesting was reconstructed from a diff whose
    # indentation was stripped; the `with` is assumed to end right after
    # `del compile_only_model` -- confirm against the repository.
    with TemporaryDirectory() as tmpdirname:
        loaded_model.save_pretrained(tmpdirname)
        folder_contents = os.listdir(tmpdirname)
        ov_bin_file_name = OV_XML_FILE_NAME.replace(".xml", ".bin")
        # assertIn reports the missing name and the listing on failure,
        # unlike assertTrue(x in y) which only says "False is not true".
        self.assertIn(OV_XML_FILE_NAME, folder_contents)
        self.assertIn(ov_bin_file_name, folder_contents)
        model = OVModelForCausalLM.from_pretrained(tmpdirname)
        self.assertEqual(model.use_cache, loaded_model.use_cache)

        compile_only_model = OVModelForCausalLM.from_pretrained(tmpdirname, compile_only=True)
        self.assertIsInstance(compile_only_model.model, ov.runtime.CompiledModel)
        self.assertIsInstance(compile_only_model.request, ov.runtime.InferRequest)
        outputs = compile_only_model(**tokens)
        self.assertTrue(
            torch.equal(loaded_model_outputs.logits, outputs.logits),
            "compile_only model logits differ from the originally loaded model",
        )
        del compile_only_model

    outputs = model(**tokens)
    self.assertTrue(
        torch.equal(loaded_model_outputs.logits, outputs.logits),
        "reloaded model logits differ from the originally loaded model",
    )
    # Drop the model references and force a collection pass.
    del loaded_model
    del model
    gc.collect()
613+
582614
583615class PipelineTest (unittest .TestCase ):
584616 def test_load_model_from_hub (self ):
0 commit comments