88from pathlib import Path
99from typing import Optional
1010
11+ import cv2
1112import gradio as gr
1213import numpy as np
1314import openvino as ov
2728
# Image-classification pipeline used to screen generated images; assigned in download_models().
safety_checker: Optional[Pipeline] = None

# Loaded OpenVINO GenAI pipelines, one cache per pipeline kind, keyed by device name.
ov_pipelines_t2i: dict[str, "genai.Text2ImagePipeline"] = {}
ov_pipelines_i2i: dict[str, "genai.Image2ImagePipeline"] = {}

# Set to True by stop(); generate_images() resets it to `not endless_generation` on every call.
stop_generating: bool = True
# Name of the model passed to load_pipeline() by generate_images().
hf_model_name: Optional[str] = None
@@ -39,15 +41,15 @@ def get_available_devices() -> list[str]:
3941 return list ({device .split ("." )[0 ] for device in core .available_devices if device != "NPU" })
4042
4143
42- def download_models (model_name : str , safety_checker_model : str ) -> None :
44+ def download_models (model_name , safety_checker_model : str ) -> None :
4345 global safety_checker
4446
4547 is_openvino_model = model_name .split ("/" )[0 ] == "OpenVINO"
4648
4749 output_dir = MODEL_DIR / model_name
4850 if not output_dir .exists ():
4951 if is_openvino_model :
50- snapshot_download (model_name , local_dir = output_dir )
52+ snapshot_download (model_name , local_dir = output_dir , resume_download = True )
5153 else :
5254 raise ValueError (f"Model { model_name } is not from OpenVINO Hub and not supported" )
5355
@@ -62,35 +64,51 @@ def download_models(model_name: str, safety_checker_model: str) -> None:
6264 image_processor = AutoProcessor .from_pretrained (safety_checker_dir ))
6365
6466
async def load_pipeline(model_name: str, device: str, pipeline: str):
    """Return the OpenVINO GenAI pipeline cached for ``device``, creating it on first use.

    Args:
        model_name: Model directory name, resolved relative to ``MODEL_DIR``.
        device: Device name passed to the pipeline constructor; also the cache key.
        pipeline: Kind of pipeline to load, either ``"text2image"`` or ``"image2image"``.

    Returns:
        The ``genai.Text2ImagePipeline`` or ``genai.Image2ImagePipeline`` instance
        cached for ``device``.

    Raises:
        ValueError: If ``pipeline`` is not one of the two supported kinds.
    """
    # Resolve the pipeline kind up front so that a typo fails loudly here
    # instead of silently handing ``None`` back to the caller.
    if pipeline == "text2image":
        cache, pipeline_cls = ov_pipelines_t2i, genai.Text2ImagePipeline
    elif pipeline == "image2image":
        cache, pipeline_cls = ov_pipelines_i2i, genai.Image2ImagePipeline
    else:
        raise ValueError(f"Unsupported pipeline type {pipeline!r}; expected 'text2image' or 'image2image'")

    # NOTE: the caches are keyed by device only, so a pipeline already loaded
    # for a device is reused even when ``model_name`` differs from the first call.
    if device not in cache:
        model_dir = MODEL_DIR / model_name
        ov_config = {"CACHE_DIR": "cache"}
        cache[device] = pipeline_cls(model_dir, device, **ov_config)

    return cache[device]
7585
async def stop():
    """Set the module-level ``stop_generating`` flag to ``True``.

    Presumably the generation loop polls this flag to decide when to exit;
    that check is not visible in this part of the file.
    """
    global stop_generating
    stop_generating = True
7989
8090
81- async def generate_images (prompt : str , seed : int , size : int , guidance_scale : float , num_inference_steps : int , randomize_seed : bool , device : str , endless_generation : bool ) -> tuple [np .ndarray , float ]:
91+ async def generate_images (input_image : np .ndarray , prompt : str , seed : int , size : int , guidance_scale : float , num_inference_steps : int ,
92+ strength : float , randomize_seed : bool , device : str , endless_generation : bool ) -> tuple [np .ndarray , float ]:
8293 global stop_generating
8394 stop_generating = not endless_generation
8495
85- ov_pipeline = await load_pipeline (hf_model_name , device )
86-
8796 while True :
8897 if randomize_seed :
8998 seed = random .randint (0 , MAX_SEED )
9099
91100 start_time = time .time ()
92- result = ov_pipeline .generate (prompt = prompt , num_inference_steps = num_inference_steps , width = size , height = size ,
93- guidance_scale = guidance_scale , generator = genai .CppStdGenerator (seed )).data [0 ]
101+ if input_image is None :
102+ ov_pipeline = await load_pipeline (hf_model_name , device , "text2image" )
103+ result = ov_pipeline .generate (prompt = prompt , num_inference_steps = num_inference_steps , width = size , height = size ,
104+ guidance_scale = guidance_scale , rng_seed = seed ).data [0 ]
105+ else :
106+ ov_pipeline = await load_pipeline (hf_model_name , device , "image2image" )
107+ # ensure image is square
108+ input_image = utils .crop_center (input_image )
109+ input_image = cv2 .resize (input_image , (size , size ))
110+ result = ov_pipeline .generate (prompt = prompt , image = ov .Tensor (input_image [None ]), num_inference_steps = num_inference_steps , width = size , height = size ,
111+ guidance_scale = guidance_scale , strength = 1.0 - strength , rng_seed = seed ).data [0 ]
94112 end_time = time .time ()
95113
96114 label = safety_checker (Image .fromarray (result ), top_k = 1 )
@@ -113,13 +131,17 @@ async def generate_images(prompt: str, seed: int, size: int, guidance_scale: flo
113131
114132
115133def build_ui ():
116- examples = [
134+ examples_t2i = [
117135 "A sail boat on a grass field with mountains in the morning and sunny day" ,
136+ "A beautiful sunset with a sail boat on the ocean, photograph, highly detailed, golden hour, Nikon D850" ,
118137 "Portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour,"
119- "Style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography" ,
120- "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k" ,
121- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" ,
122- "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece" ,
138+ "Style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography"
139+ ]
140+
141+ examples_i2i = [
142+ "Make me a super hero, 8k" ,
143+ "Make me a beautiful cyborg with golden hair, 8k" ,
144+ "Make me an astronaut, cold color palette, muted colors, 8k"
123145 ]
124146
125147 with gr .Blocks () as demo :
@@ -132,9 +154,11 @@ def build_ui():
132154 )
133155 with gr .Row ():
134156 with gr .Column ():
135- result_img = gr .Image (label = "Generated image" , elem_id = "output_image" , format = "png" )
136157 with gr .Row ():
137- result_time_label = gr .Text ("" , label = "Inference Time" , type = "text" )
158+ input_image = gr .Image (label = "Input image (leave blank for text2image generation)" , sources = ["webcam" , "clipboard" , "upload" ])
159+ result_img = gr .Image (label = "Generated image" , elem_id = "output_image" , format = "png" )
160+ with gr .Row ():
161+ result_time_label = gr .Text ("" , label = "Inference time" , type = "text" )
138162 with gr .Row ():
139163 start_button = gr .Button ("Start generation" )
140164 stop_button = gr .Button ("Stop generation" )
@@ -167,25 +191,43 @@ def build_ui():
167191 step = 1 ,
168192 value = 5 ,
169193 )
194+ with gr .Row ():
195+ strength_slider = gr .Slider (
196+ label = "Input image influence strength" ,
197+ minimum = 0.0 ,
198+ maximum = 1.0 ,
199+ step = 0.01 ,
200+ value = 0.5
201+ )
202+ size_slider = gr .Slider (
203+ label = "Image size" ,
204+ minimum = 256 ,
205+ maximum = 1024 ,
206+ step = 64 ,
207+ value = 512
208+ )
170209
171- size_slider = gr .Slider (
172- label = "Image size" ,
173- minimum = 256 ,
174- maximum = 1024 ,
175- step = 64 ,
176- value = 512
177- )
178-
179210 gr .Examples (
180- examples = examples ,
211+ label = "Examples for Text2Image" ,
212+ examples = examples_t2i ,
213+ inputs = prompt_text ,
214+ outputs = result_img ,
215+ cache_examples = False ,
216+ )
217+
218+ gr .Examples (
219+ label = "Examples for Image2Image" ,
220+ examples = examples_i2i ,
181221 inputs = prompt_text ,
182222 outputs = result_img ,
183223 cache_examples = False ,
184224 )
225+
185226 # clicking run
186227 gr .on (triggers = [prompt_text .submit , start_button .click ],
187228 fn = generate_images ,
188- inputs = [prompt_text , seed_slider , size_slider , guidance_scale_slider , num_inference_steps_slider , randomize_seed_checkbox , device_dropdown , endless_checkbox ],
229+ inputs = [input_image , prompt_text , seed_slider , size_slider , guidance_scale_slider , num_inference_steps_slider ,
230+ strength_slider , randomize_seed_checkbox , device_dropdown , endless_checkbox ],
189231 outputs = [result_img , result_time_label ]
190232 )
191233 # clicking stop
@@ -214,7 +256,7 @@ def run_endless_lcm(model_name: str, safety_checker_model: str, local_network: b
214256 parser = argparse .ArgumentParser ()
215257 parser .add_argument ("--model_name" , type = str , default = "OpenVINO/LCM_Dreamshaper_v7-fp16-ov" ,
216258 choices = ["OpenVINO/LCM_Dreamshaper_v7-int8-ov" , "OpenVINO/LCM_Dreamshaper_v7-fp16-ov" ],
217- help = "Visual GenAI model to be used" )
259+ help = "GenAI model to be used" )
218260 parser .add_argument ("--safety_checker_model" , type = str , default = "Falconsai/nsfw_image_detection" ,
219261 choices = ["Falconsai/nsfw_image_detection" ], help = "The model to verify if the generated image is NSFW" )
220262 parser .add_argument ("--local_network" , action = "store_true" , help = "Whether demo should be available in local network" )
0 commit comments