Skip to content

Commit b2c110b

Browse files
committed
show model when vote
1 parent 59207c1 commit b2c110b

File tree

1 file changed

+109
-32
lines changed

1 file changed

+109
-32
lines changed
Lines changed: 109 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
import os
22
import gradio as gr
33
import requests
4-
import base64
4+
import time
55
import io
66
import hashlib
77
from PIL import Image
88
import datetime
99
import json
10+
import random
1011

1112
from fastchat.constants import LOGDIR
1213
from fastchat.utils import upload_image_file_to_gcs
14+
from fastchat.serve.gradio_web_server import enable_btn, disable_btn
1315

1416
FIREWORKS_API_KEY = os.getenv("FIREWORKS_API_KEY")
1517
API_BASE = "https://api.fireworks.ai/inference/v1/workflows/accounts/fireworks/models/{model}/text_to_image"
@@ -19,13 +21,34 @@
1921
"flux-1-dev-fp8",
2022
"flux-1-schnell-fp8",
2123
]
24+
ANONY_NAMES = ["", ""]
25+
26+
class State:
27+
def __init__(self, model_name):
28+
self.model_name = model_name
29+
self.prompt = ""
30+
self.image_filename = ""
31+
self.generated_image = None
2232

2333
def get_conv_log_filename():
2434
t = datetime.datetime.now()
2535
conv_log_filename = f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json"
2636
return os.path.join(LOGDIR, f"txt2img-{conv_log_filename}")
2737

28-
def generate_image(model, prompt):
38+
def get_battle_pair(models):
39+
return random.sample(models, 2)
40+
41+
def add_text(state_left, state_right, prompt):
42+
if state_left is None or state_right is None:
43+
models = get_battle_pair(DUMMY_MODELS)
44+
state_left = State(models[0])
45+
state_right = State(models[1])
46+
47+
state_left.prompt = prompt
48+
state_right.prompt = prompt
49+
return state_left, state_right
50+
51+
def generate_image(state):
2952
"""Generate image from text prompt using Fireworks API"""
3053
headers = {
3154
"Authorization": f"Bearer {FIREWORKS_API_KEY}",
@@ -34,13 +57,13 @@ def generate_image(model, prompt):
3457
}
3558

3659
data = {
37-
"prompt": prompt,
60+
"prompt": state.prompt,
3861
"aspect_ratio": "16:9",
3962
"guidance_scale": 4.5,
4063
"num_inference_steps": 3,
4164
}
4265

43-
api_url = API_BASE.format(model=model)
66+
api_url = API_BASE.format(model=state.model_name)
4467

4568
try:
4669
response = requests.post(api_url, headers=headers, json=data)
@@ -49,6 +72,7 @@ def generate_image(model, prompt):
4972

5073
image_bytes = response.content
5174
image = Image.open(io.BytesIO(image_bytes))
75+
state.generated_image = image
5276

5377
except requests.exceptions.RequestException as e:
5478
return f"Error generating image: {str(e)}"
@@ -57,78 +81,131 @@ def generate_image(model, prompt):
5781
image_hash = hashlib.md5(image.tobytes()).hexdigest()
5882
image_filename = f"{image_hash}.png"
5983
upload_image_file_to_gcs(image, image_filename)
84+
state.image_filename = image_filename
6085

6186
with open(log_filename, "a") as f:
6287
data = {
63-
"model": model,
64-
"prompt": prompt,
88+
"model": state.model_name,
89+
"prompt": state.prompt,
6590
"image_filename": image_filename
6691
}
6792
f.write(json.dumps(data) + "\n")
6893

6994
return image
7095

71-
def generate_image_multi(model_left, model_right, prompt):
96+
def generate_image_multi(state_left, state_right):
97+
# Randomly sample two different models
7298
images = []
73-
for model in [model_left, model_right]:
74-
images.append(generate_image(model, prompt))
99+
states = [state_left, state_right]
100+
for i in range(2):
101+
images.append(generate_image(states[i]))
75102

76103
return images
77104

105+
def flash_buttons():
106+
btn_updates = [
107+
[disable_btn] * 4,
108+
[enable_btn] * 4,
109+
]
110+
for i in range(4):
111+
yield btn_updates[i % 2]
112+
time.sleep(0.3)
113+
114+
def reveal_models(state_left, state_right):
115+
return [f"Model A: {state_left.model_name}", f"Model B: {state_right.model_name}"]
78116

79117
# Create Gradio interface
80118
with gr.Blocks(title="Text to Image Generator") as demo:
81119
gr.Markdown("# Text to Image Generator")
82120
gr.Markdown("Enter a text prompt to generate an image")
83121

84-
85122
num_sides = 2
86123
model_selectors = [None] * num_sides
124+
states = [gr.State() for _ in range(num_sides)]
87125

88-
with gr.Column():
126+
with gr.Column():
89127
with gr.Group():
128+
with gr.Row():
129+
output_left = gr.Image(
130+
type="pil",
131+
show_label=False
132+
)
133+
output_right = gr.Image(
134+
type="pil",
135+
show_label=False
136+
)
137+
90138
with gr.Row():
91139
for i in range(num_sides):
92-
model_selectors[i] = gr.Dropdown(
93-
choices=DUMMY_MODELS,
94-
value=DUMMY_MODELS[i] if DUMMY_MODELS else "",
95-
interactive=True,
140+
model_selectors[i] = gr.Markdown(
141+
value=ANONY_NAMES[i],
96142
show_label=False,
97-
container=False,
98-
)
99-
100-
101-
with gr.Group():
102-
with gr.Row():
103-
output_left = gr.Image(
104-
type="pil",
105-
show_label=False
106-
)
107-
output_right = gr.Image(
108-
type="pil",
109-
show_label=False
110143
)
111144

145+
with gr.Row():
146+
left_btn = gr.Button(
147+
value="Left",
148+
interactive=False,
149+
visible=False
150+
)
151+
tie_btn = gr.Button(
152+
value="Tie",
153+
interactive=False,
154+
visible=False
155+
)
156+
right_btn = gr.Button(
157+
value="Right",
158+
interactive=False,
159+
visible=False
160+
)
161+
idk_btn = gr.Button(
162+
value="IDK",
163+
interactive=False,
164+
visible=False
165+
)
166+
112167
with gr.Row():
113168
text_input = gr.Textbox(
114169
label="Prompt",
115170
placeholder="Enter your prompt here...",
116171
show_label=False
117172
)
118173
send_btn = gr.Button("Generate", variant="primary")
119-
174+
120175
# Handle generation
121-
send_btn.click(
176+
gen_output = send_btn.click(
177+
fn=add_text,
178+
inputs=states + [text_input],
179+
outputs=states
180+
).then(
122181
fn=generate_image_multi,
123-
inputs=model_selectors + [text_input],
182+
inputs=states,
124183
outputs=[output_left, output_right]
184+
).then(
185+
fn=flash_buttons,
186+
outputs=[left_btn, tie_btn, right_btn, idk_btn]
125187
)
126188

127189
text_input.submit(
190+
fn=add_text,
191+
inputs=states + [text_input],
192+
outputs=states
193+
).then(
128194
fn=generate_image_multi,
129-
inputs=model_selectors + [text_input],
195+
inputs=states,
130196
outputs=[output_left, output_right]
197+
).then(
198+
fn=flash_buttons,
199+
outputs=[left_btn, tie_btn, right_btn, idk_btn]
131200
)
132201

202+
# Handle voting buttons
203+
for btn in [left_btn, tie_btn, right_btn, idk_btn]:
204+
btn.click(
205+
fn=reveal_models,
206+
inputs=states,
207+
outputs=model_selectors
208+
)
209+
133210
if __name__ == "__main__":
134211
demo.launch(share=True)

0 commit comments

Comments
 (0)