Skip to content

Commit 59207c1

Browse files
committed
add simple side-by-side
1 parent 79a1722 commit 59207c1

File tree

1 file changed

+68
-21
lines changed

1 file changed

+68
-21
lines changed

fastchat/serve/gradio_block_arena_txt2img.py

Lines changed: 68 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,29 @@
33
import requests
44
import base64
55
import io
6+
import hashlib
67
from PIL import Image
8+
import datetime
9+
import json
10+
11+
from fastchat.constants import LOGDIR
12+
from fastchat.utils import upload_image_file_to_gcs
713

814
FIREWORKS_API_KEY = os.getenv("FIREWORKS_API_KEY")
9-
API_BASE = "https://api.fireworks.ai/inference/v1/workflows/accounts/fireworks/models/{}/text_to_image"
15+
API_BASE = "https://api.fireworks.ai/inference/v1/workflows/accounts/fireworks/models/{model}/text_to_image"
1016
DUMMY_MODELS = ["stable-diffusion-3p5-medium",
1117
"stable-diffusion-3p5-large",
1218
"stable-diffusion-3p5-large-turbo",
1319
"flux-1-dev-fp8",
1420
"flux-1-schnell-fp8",
1521
]
1622

17-
def generate_image(prompt, model):
23+
def get_conv_log_filename():
24+
t = datetime.datetime.now()
25+
conv_log_filename = f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json"
26+
return os.path.join(LOGDIR, f"txt2img-{conv_log_filename}")
27+
28+
def generate_image(model, prompt):
1829
"""Generate image from text prompt using Fireworks API"""
1930
headers = {
2031
"Authorization": f"Bearer {FIREWORKS_API_KEY}",
@@ -38,29 +49,65 @@ def generate_image(prompt, model):
3849

3950
image_bytes = response.content
4051
image = Image.open(io.BytesIO(image_bytes))
41-
42-
return image
43-
52+
4453
except requests.exceptions.RequestException as e:
4554
return f"Error generating image: {str(e)}"
4655

56+
log_filename = get_conv_log_filename()
57+
image_hash = hashlib.md5(image.tobytes()).hexdigest()
58+
image_filename = f"{image_hash}.png"
59+
upload_image_file_to_gcs(image, image_filename)
60+
61+
with open(log_filename, "a") as f:
62+
data = {
63+
"model": model,
64+
"prompt": prompt,
65+
"image_filename": image_filename
66+
}
67+
f.write(json.dumps(data) + "\n")
68+
69+
return image
70+
71+
def generate_image_multi(model_left, model_right, prompt):
72+
images = []
73+
for model in [model_left, model_right]:
74+
images.append(generate_image(model, prompt))
75+
76+
return images
77+
78+
4779
# Create Gradio interface
4880
with gr.Blocks(title="Text to Image Generator") as demo:
4981
gr.Markdown("# Text to Image Generator")
5082
gr.Markdown("Enter a text prompt to generate an image")
5183

84+
85+
num_sides = 2
86+
model_selectors = [None] * num_sides
87+
5288
with gr.Column():
5389
with gr.Group():
54-
model_selector = gr.Dropdown(
55-
choices=DUMMY_MODELS,
56-
interactive=True,
57-
show_label=False,
58-
container=False,
59-
)
60-
image_output = gr.Image(
61-
label="Generated Image",
62-
type="pil"
63-
)
90+
with gr.Row():
91+
for i in range(num_sides):
92+
model_selectors[i] = gr.Dropdown(
93+
choices=DUMMY_MODELS,
94+
value=DUMMY_MODELS[i] if DUMMY_MODELS else "",
95+
interactive=True,
96+
show_label=False,
97+
container=False,
98+
)
99+
100+
101+
with gr.Group():
102+
with gr.Row():
103+
output_left = gr.Image(
104+
type="pil",
105+
show_label=False
106+
)
107+
output_right = gr.Image(
108+
type="pil",
109+
show_label=False
110+
)
64111

65112
with gr.Row():
66113
text_input = gr.Textbox(
@@ -72,15 +119,15 @@ def generate_image(prompt, model):
72119

73120
# Handle generation
74121
send_btn.click(
75-
fn=generate_image,
76-
inputs=[text_input, model_selector],
77-
outputs=image_output
122+
fn=generate_image_multi,
123+
inputs=model_selectors + [text_input],
124+
outputs=[output_left, output_right]
78125
)
79126

80127
text_input.submit(
81-
fn=generate_image,
82-
inputs=[text_input, model_selector],
83-
outputs=image_output
128+
fn=generate_image_multi,
129+
inputs=model_selectors + [text_input],
130+
outputs=[output_left, output_right]
84131
)
85132

86133
if __name__ == "__main__":

0 commit comments

Comments
 (0)