99from PIL import Image
1010import torch
1111import openvino .runtime as ov
12+ from openvino import convert_model
1213
1314import tomeov
1415from diffusers import StableDiffusionPipeline , DDPMScheduler
1516from optimum .intel .openvino import OVStableDiffusionPipeline
17+ from optimum .exporters .openvino import export_from_model
1618import open_clip
1719import timm
1820
@@ -33,7 +35,7 @@ def test_stable_diffusion(self):
3335 tomeov .patch_stable_diffusion (loaded_pipeline , ratio = 0.3 )
3436
3537 with tempfile .TemporaryDirectory () as tmpdirname :
36- tomeov . export_diffusion_pipeline (loaded_pipeline , tmpdirname )
38+ export_from_model (loaded_pipeline , tmpdirname )
3739 ov_pipe = OVStableDiffusionPipeline .from_pretrained (tmpdirname , compile = False )
3840 ov_pipe .reshape (batch_size = 1 , height = height , width = width , num_images_per_prompt = 1 )
3941 ov_pipe .compile ()
@@ -42,26 +44,16 @@ def test_stable_diffusion(self):
4244 def test_openclip (self ):
4345 model , _ , transform = open_clip .create_model_and_transforms (self .OPENCLIP_MODEL [0 ], pretrained = self .OPENCLIP_MODEL [1 ])
4446 tomeov .patch_openclip (model , 8 )
45- dummy_image = np .random .rand (100 , 100 , 3 ) * 255
47+ dummy_image = np .random .rand (224 , 224 , 3 ) * 255
4648 dummy_image = Image .fromarray (dummy_image .astype ("uint8" ))
4749 dummy_image = transform (dummy_image ).unsqueeze (0 )
4850
49- with tempfile .TemporaryDirectory (suffix = ".onnx" ) as tmpdirname :
50- model_file = os .path .join (tmpdirname , "image_encoder.onnx" )
51- torch .onnx .export (
52- model .visual ,
53- dummy_image ,
54- model_file ,
55- opset_version = 14 ,
56- input_names = ["image" ],
57- output_names = ["image_embedding" ],
58- dynamic_axes = {
59- "image" : {0 : "batch" },
60- "image_embedding" : {0 : "batch" },
61- }
62- )
63- compiled_model = ov .compile_model (model_file )
64- self .assertTrue (compiled_model )
51+ ov_model = convert_model (
52+ model .visual ,
53+ example_input = dummy_image
54+ )
55+ compiled_model = ov .compile_model (ov_model )
56+ self .assertTrue (compiled_model )
6557
6658 def test_timm (self ):
6759 model = timm .create_model (self .TIMM_MODEL , pretrained = False )
0 commit comments