Skip to content

Commit a20051d

Browse files
authored
fix checking available files if from_onnx=True (#1208)
* fix checking available files if from_onnx=True
* add test
* fix diffusers loading
* restore old export guessing behavior
* Update optimum/intel/openvino/modeling_base.py
1 parent 4256352 commit a20051d

File tree

2 files changed

+33
-1
lines changed

2 files changed

+33
-1
lines changed

optimum/intel/openvino/modeling_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@ def from_pretrained(
459459

460460
ov_files = _find_files_matching_pattern(
461461
model_dir,
462-
pattern=r"(.*)?openvino(.*)?\_model(.*)?.xml$",
462+
pattern=r"(.*)?openvino(.*)?\_model(.*)?.xml$" if not kwargs.get("from_onnx", False) else "*.onnx",
463463
subfolder=subfolder,
464464
use_auth_token=token,
465465
revision=revision,

tests/openvino/test_modeling.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,38 @@ def test_find_files_matching_pattern_with_quantized_ov_model(self):
579579
ov_files = _find_files_matching_pattern(local_dir, pattern=pattern, subfolder=subfolder)
580580
self.assertTrue(len(ov_files) == 1)
581581

582+
def test_load_from_hub_onnx_model_and_save(self):
583+
model_id = "katuni4ka/tiny-random-LlamaForCausalLM-onnx"
584+
tokenizer = AutoTokenizer.from_pretrained(model_id)
585+
tokens = tokenizer("This is a sample input", return_tensors="pt")
586+
loaded_model = OVModelForCausalLM.from_pretrained(model_id, from_onnx=True)
587+
self.assertIsInstance(loaded_model.config, PretrainedConfig)
588+
# Test that PERFORMANCE_HINT is set to LATENCY by default
589+
self.assertEqual(loaded_model.ov_config.get("PERFORMANCE_HINT"), "LATENCY")
590+
self.assertEqual(loaded_model.request.get_compiled_model().get_property("PERFORMANCE_HINT"), "LATENCY")
591+
loaded_model_outputs = loaded_model(**tokens)
592+
593+
with TemporaryDirectory() as tmpdirname:
594+
loaded_model.save_pretrained(tmpdirname)
595+
folder_contents = os.listdir(tmpdirname)
596+
self.assertTrue(OV_XML_FILE_NAME in folder_contents)
597+
self.assertTrue(OV_XML_FILE_NAME.replace(".xml", ".bin") in folder_contents)
598+
model = OVModelForCausalLM.from_pretrained(tmpdirname)
599+
self.assertEqual(model.use_cache, loaded_model.use_cache)
600+
601+
compile_only_model = OVModelForCausalLM.from_pretrained(tmpdirname, compile_only=True)
602+
self.assertIsInstance(compile_only_model.model, ov.runtime.CompiledModel)
603+
self.assertIsInstance(compile_only_model.request, ov.runtime.InferRequest)
604+
outputs = compile_only_model(**tokens)
605+
self.assertTrue(torch.equal(loaded_model_outputs.logits, outputs.logits))
606+
del compile_only_model
607+
608+
outputs = model(**tokens)
609+
self.assertTrue(torch.equal(loaded_model_outputs.logits, outputs.logits))
610+
del loaded_model
611+
del model
612+
gc.collect()
613+
582614

583615
class PipelineTest(unittest.TestCase):
584616
def test_load_model_from_hub(self):

0 commit comments

Comments (0)