Skip to content

Commit c140395

Browse files
Add image2image pipeline to paint your dreams demo (#179)
* Update main.py: Add an input image field to the UI (not used yet), and download and convert the dreamlike_anime 1.0 model, with some logic added for using two different models (text2image and image2image), setting up for adding the image2image pipeline later on. * Add image2image pipeline: Add image2image code to exist alongside the text2image code. If the input image is left blank, the text2image pipeline will be used. Also added some example prompts for image2image. * Update requirements: Update requirements for the image2image pipeline * Small refactor * Changed UI * Use one model for both tasks * Crop and resize the image, use strength parameter * Removed support for models outside OV hub * Updated OV version --------- Co-authored-by: whitneyfoster <whitney.foster@intel.com>
1 parent 541ddf2 commit c140395

File tree

3 files changed

+84
-35
lines changed

3 files changed

+84
-35
lines changed

demos/paint_your_dreams_demo/main.py

Lines changed: 75 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pathlib import Path
99
from typing import Optional
1010

11+
import cv2
1112
import gradio as gr
1213
import numpy as np
1314
import openvino as ov
@@ -27,7 +28,8 @@
2728

2829
safety_checker: Optional[Pipeline] = None
2930

30-
ov_pipelines = {}
31+
ov_pipelines_t2i = {}
32+
ov_pipelines_i2i = {}
3133

3234
stop_generating: bool = True
3335
hf_model_name: Optional[str] = None
@@ -39,15 +41,15 @@ def get_available_devices() -> list[str]:
3941
return list({device.split(".")[0] for device in core.available_devices if device != "NPU"})
4042

4143

42-
def download_models(model_name: str, safety_checker_model: str) -> None:
44+
def download_models(model_name, safety_checker_model: str) -> None:
4345
global safety_checker
4446

4547
is_openvino_model = model_name.split("/")[0] == "OpenVINO"
4648

4749
output_dir = MODEL_DIR / model_name
4850
if not output_dir.exists():
4951
if is_openvino_model:
50-
snapshot_download(model_name, local_dir=output_dir)
52+
snapshot_download(model_name, local_dir=output_dir, resume_download=True)
5153
else:
5254
raise ValueError(f"Model {model_name} is not from OpenVINO Hub and not supported")
5355

@@ -62,35 +64,51 @@ def download_models(model_name: str, safety_checker_model: str) -> None:
6264
image_processor=AutoProcessor.from_pretrained(safety_checker_dir))
6365

6466

65-
async def load_pipeline(model_name: str, device: str):
66-
if device not in ov_pipelines:
67-
model_dir = MODEL_DIR / model_name
68-
ov_config = {"CACHE_DIR": "cache"}
67+
# Return the OpenVINO GenAI pipeline for `device`, creating and caching it on first use.
# `pipeline` selects the kind: "text2image" builds a genai.Text2ImagePipeline cached in
# ov_pipelines_t2i; "image2image" builds a genai.Image2ImagePipeline cached in ov_pipelines_i2i.
# Both are constructed from MODEL_DIR / model_name with CACHE_DIR set to "cache".
# The caches are keyed by device only, not by model_name.
# NOTE(review): no return is visible for any other value of `pipeline`; it appears to fall
# through and return None - confirm against the full file.
async def load_pipeline(model_name: str, device: str, pipeline: str):
68+
model_dir = MODEL_DIR / model_name
69+
ov_config = {"CACHE_DIR": "cache"}
6970

70-
ov_pipeline = genai.Text2ImagePipeline(model_dir, device, **ov_config)
71-
ov_pipelines[device] = ov_pipeline
71+
if pipeline == "text2image":
72+
if device not in ov_pipelines_t2i:
73+
ov_pipeline = genai.Text2ImagePipeline(model_dir, device, **ov_config)
74+
ov_pipelines_t2i[device] = ov_pipeline
7275

73-
return ov_pipelines[device]
76+
return ov_pipelines_t2i[device]
77+
78+
if pipeline == "image2image":
79+
if device not in ov_pipelines_i2i:
80+
ov_pipeline = genai.Image2ImagePipeline(model_dir, device, **ov_config)
81+
ov_pipelines_i2i[device] = ov_pipeline
82+
83+
return ov_pipelines_i2i[device]
7484

7585

7686
async def stop():
7787
global stop_generating
7888
stop_generating = True
7989

8090

81-
async def generate_images(prompt: str, seed: int, size: int, guidance_scale: float, num_inference_steps: int, randomize_seed: bool, device: str, endless_generation: bool) -> tuple[np.ndarray, float]:
91+
async def generate_images(input_image: np.ndarray, prompt: str, seed: int, size: int, guidance_scale: float, num_inference_steps: int,
92+
strength: float, randomize_seed: bool, device: str, endless_generation: bool) -> tuple[np.ndarray, float]:
8293
global stop_generating
8394
stop_generating = not endless_generation
8495

85-
ov_pipeline = await load_pipeline(hf_model_name, device)
86-
8796
while True:
8897
if randomize_seed:
8998
seed = random.randint(0, MAX_SEED)
9099

91100
start_time = time.time()
92-
result = ov_pipeline.generate(prompt=prompt, num_inference_steps=num_inference_steps, width=size, height=size,
93-
guidance_scale=guidance_scale, generator=genai.CppStdGenerator(seed)).data[0]
101+
if input_image is None:
102+
ov_pipeline = await load_pipeline(hf_model_name, device, "text2image")
103+
result = ov_pipeline.generate(prompt=prompt, num_inference_steps=num_inference_steps, width=size, height=size,
104+
guidance_scale=guidance_scale, rng_seed=seed).data[0]
105+
else:
106+
ov_pipeline = await load_pipeline(hf_model_name, device, "image2image")
107+
# ensure image is square
108+
input_image = utils.crop_center(input_image)
109+
input_image = cv2.resize(input_image, (size, size))
110+
result = ov_pipeline.generate(prompt=prompt, image=ov.Tensor(input_image[None]), num_inference_steps=num_inference_steps, width=size, height=size,
111+
guidance_scale=guidance_scale, strength=1.0 - strength, rng_seed=seed).data[0]
94112
end_time = time.time()
95113

96114
label = safety_checker(Image.fromarray(result), top_k=1)
@@ -113,13 +131,17 @@ async def generate_images(prompt: str, seed: int, size: int, guidance_scale: flo
113131

114132

115133
def build_ui():
116-
examples = [
134+
examples_t2i = [
117135
"A sail boat on a grass field with mountains in the morning and sunny day",
136+
"A beautiful sunset with a sail boat on the ocean, photograph, highly detailed, golden hour, Nikon D850",
118137
"Portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour,"
119-
"Style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography",
120-
"Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
121-
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
122-
"A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
138+
"Style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography"
139+
]
140+
141+
examples_i2i = [
142+
"Make me a super hero, 8k",
143+
"Make me a beautiful cyborg with golden hair, 8k",
144+
"Make me an astronaut, cold color palette, muted colors, 8k"
123145
]
124146

125147
with gr.Blocks() as demo:
@@ -132,9 +154,11 @@ def build_ui():
132154
)
133155
with gr.Row():
134156
with gr.Column():
135-
result_img = gr.Image(label="Generated image", elem_id="output_image", format="png")
136157
with gr.Row():
137-
result_time_label = gr.Text("", label="Inference Time", type="text")
158+
input_image = gr.Image(label="Input image (leave blank for text2image generation)", sources=["webcam", "clipboard", "upload"])
159+
result_img = gr.Image(label="Generated image", elem_id="output_image", format="png")
160+
with gr.Row():
161+
result_time_label = gr.Text("", label="Inference time", type="text")
138162
with gr.Row():
139163
start_button = gr.Button("Start generation")
140164
stop_button = gr.Button("Stop generation")
@@ -167,25 +191,43 @@ def build_ui():
167191
step=1,
168192
value=5,
169193
)
194+
with gr.Row():
195+
strength_slider = gr.Slider(
196+
label="Input image influence strength",
197+
minimum=0.0,
198+
maximum=1.0,
199+
step=0.01,
200+
value=0.5
201+
)
202+
size_slider = gr.Slider(
203+
label="Image size",
204+
minimum=256,
205+
maximum=1024,
206+
step=64,
207+
value=512
208+
)
170209

171-
size_slider = gr.Slider(
172-
label="Image size",
173-
minimum=256,
174-
maximum=1024,
175-
step=64,
176-
value=512
177-
)
178-
179210
gr.Examples(
180-
examples=examples,
211+
label="Examples for Text2Image",
212+
examples=examples_t2i,
213+
inputs=prompt_text,
214+
outputs=result_img,
215+
cache_examples=False,
216+
)
217+
218+
gr.Examples(
219+
label="Examples for Image2Image",
220+
examples=examples_i2i,
181221
inputs=prompt_text,
182222
outputs=result_img,
183223
cache_examples=False,
184224
)
225+
185226
# clicking run
186227
gr.on(triggers=[prompt_text.submit, start_button.click],
187228
fn=generate_images,
188-
inputs=[prompt_text, seed_slider, size_slider, guidance_scale_slider, num_inference_steps_slider, randomize_seed_checkbox, device_dropdown, endless_checkbox],
229+
inputs=[input_image, prompt_text, seed_slider, size_slider, guidance_scale_slider, num_inference_steps_slider,
230+
strength_slider, randomize_seed_checkbox, device_dropdown, endless_checkbox],
189231
outputs=[result_img, result_time_label]
190232
)
191233
# clicking stop
@@ -214,7 +256,7 @@ def run_endless_lcm(model_name: str, safety_checker_model: str, local_network: b
214256
parser = argparse.ArgumentParser()
215257
parser.add_argument("--model_name", type=str, default="OpenVINO/LCM_Dreamshaper_v7-fp16-ov",
216258
choices=["OpenVINO/LCM_Dreamshaper_v7-int8-ov", "OpenVINO/LCM_Dreamshaper_v7-fp16-ov"],
217-
help="Visual GenAI model to be used")
259+
help="GenAI model to be used")
218260
parser.add_argument("--safety_checker_model", type=str, default="Falconsai/nsfw_image_detection",
219261
choices=["Falconsai/nsfw_image_detection"], help="The model to verify if the generated image is NSFW")
220262
parser.add_argument("--local_network", action="store_true", help="Whether demo should be available in local network")

demos/paint_your_dreams_demo/requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
--extra-index-url https://download.pytorch.org/whl/cpu
22

3-
openvino==2024.6.0
4-
openvino-genai==2024.6.0
3+
openvino==2025.0
4+
openvino-genai==2025.0
55
optimum-intel==1.21.0
66
optimum==1.23.3
77
onnx==1.17.0

demos/utils/demo_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,3 +263,10 @@ def draw_text(image: ndarray, text: str, point: Tuple[int, int], center: bool =
263263

264264
cv2.putText(image, text=text, org=(text_x, text_y), fontFace=cv2.FONT_HERSHEY_DUPLEX, fontScale=font_scale * f_width / 2000, color=(0, 0, 0), thickness=2, lineType=cv2.LINE_AA)
265265
cv2.putText(image, text=text, org=(text_x, text_y), fontFace=cv2.FONT_HERSHEY_DUPLEX, fontScale=font_scale * f_width / 2000, color=font_color, thickness=1, lineType=cv2.LINE_AA)
266+
267+
268+
# Crop `image` to a centered square and return the resulting slice.
# The side length is the smaller of the first two dimensions (image.shape[0], image.shape[1]).
# start_y offsets the first axis and start_x the second, so the excess is trimmed
# equally (up to integer division) from both ends of the longer axis.
# NOTE(review): presumably `image` is laid out as (height, width[, channels]) - confirm with callers.
def crop_center(image: ndarray) -> ndarray:
269+
size = min(image.shape[:2])
270+
start_x = (image.shape[1] - size) // 2
271+
start_y = (image.shape[0] - size) // 2
272+
return image[start_y:start_y + size, start_x:start_x + size]

0 commit comments

Comments
 (0)