11import os
22import gradio as gr
33import requests
4- import base64
4+ import time
55import io
66import hashlib
77from PIL import Image
88import datetime
99import json
10+ import random
1011
1112from fastchat .constants import LOGDIR
1213from fastchat .utils import upload_image_file_to_gcs
14+ from fastchat .serve .gradio_web_server import enable_btn , disable_btn
1315
1416FIREWORKS_API_KEY = os .getenv ("FIREWORKS_API_KEY" )
1517API_BASE = "https://api.fireworks.ai/inference/v1/workflows/accounts/fireworks/models/{model}/text_to_image"
1921 "flux-1-dev-fp8" ,
2022 "flux-1-schnell-fp8" ,
2123 ]
24+ ANONY_NAMES = ["" , "" ]
25+
26+ class State :
27+ def __init__ (self , model_name ):
28+ self .model_name = model_name
29+ self .prompt = ""
30+ self .image_filename = ""
31+ self .generated_image = None
2232
2333def get_conv_log_filename ():
2434 t = datetime .datetime .now ()
2535 conv_log_filename = f"{ t .year } -{ t .month :02d} -{ t .day :02d} -conv.json"
2636 return os .path .join (LOGDIR , f"txt2img-{ conv_log_filename } " )
2737
28- def generate_image (model , prompt ):
38+ def get_battle_pair (models ):
39+ return random .sample (models , 2 )
40+
41+ def add_text (state_left , state_right , prompt ):
42+ if state_left is None or state_right is None :
43+ models = get_battle_pair (DUMMY_MODELS )
44+ state_left = State (models [0 ])
45+ state_right = State (models [1 ])
46+
47+ state_left .prompt = prompt
48+ state_right .prompt = prompt
49+ return state_left , state_right
50+
51+ def generate_image (state ):
2952 """Generate image from text prompt using Fireworks API"""
3053 headers = {
3154 "Authorization" : f"Bearer { FIREWORKS_API_KEY } " ,
@@ -34,13 +57,13 @@ def generate_image(model, prompt):
3457 }
3558
3659 data = {
37- "prompt" : prompt ,
60+ "prompt" : state . prompt ,
3861 "aspect_ratio" : "16:9" ,
3962 "guidance_scale" : 4.5 ,
4063 "num_inference_steps" : 3 ,
4164 }
4265
43- api_url = API_BASE .format (model = model )
66+ api_url = API_BASE .format (model = state . model_name )
4467
4568 try :
4669 response = requests .post (api_url , headers = headers , json = data )
@@ -49,6 +72,7 @@ def generate_image(model, prompt):
4972
5073 image_bytes = response .content
5174 image = Image .open (io .BytesIO (image_bytes ))
75+ state .generated_image = image
5276
5377 except requests .exceptions .RequestException as e :
5478 return f"Error generating image: { str (e )} "
@@ -57,78 +81,131 @@ def generate_image(model, prompt):
5781 image_hash = hashlib .md5 (image .tobytes ()).hexdigest ()
5882 image_filename = f"{ image_hash } .png"
5983 upload_image_file_to_gcs (image , image_filename )
84+ state .image_filename = image_filename
6085
6186 with open (log_filename , "a" ) as f :
6287 data = {
63- "model" : model ,
64- "prompt" : prompt ,
88+ "model" : state . model_name ,
89+ "prompt" : state . prompt ,
6590 "image_filename" : image_filename
6691 }
6792 f .write (json .dumps (data ) + "\n " )
6893
6994 return image
7095
71- def generate_image_multi (model_left , model_right , prompt ):
96+ def generate_image_multi (state_left , state_right ):
97+ # Randomly sample two different models
7298 images = []
73- for model in [model_left , model_right ]:
74- images .append (generate_image (model , prompt ))
99+ states = [state_left , state_right ]
100+ for i in range (2 ):
101+ images .append (generate_image (states [i ]))
75102
76103 return images
77104
105+ def flash_buttons ():
106+ btn_updates = [
107+ [disable_btn ] * 4 ,
108+ [enable_btn ] * 4 ,
109+ ]
110+ for i in range (4 ):
111+ yield btn_updates [i % 2 ]
112+ time .sleep (0.3 )
113+
114+ def reveal_models (state_left , state_right ):
115+ return [f"Model A: { state_left .model_name } " , f"Model B: { state_right .model_name } " ]
78116
79117# Create Gradio interface
80118with gr .Blocks (title = "Text to Image Generator" ) as demo :
81119 gr .Markdown ("# Text to Image Generator" )
82120 gr .Markdown ("Enter a text prompt to generate an image" )
83121
84-
85122 num_sides = 2
86123 model_selectors = [None ] * num_sides
124+ states = [gr .State () for _ in range (num_sides )]
87125
88- with gr .Column ():
126+ with gr .Column ():
89127 with gr .Group ():
128+ with gr .Row ():
129+ output_left = gr .Image (
130+ type = "pil" ,
131+ show_label = False
132+ )
133+ output_right = gr .Image (
134+ type = "pil" ,
135+ show_label = False
136+ )
137+
90138 with gr .Row ():
91139 for i in range (num_sides ):
92- model_selectors [i ] = gr .Dropdown (
93- choices = DUMMY_MODELS ,
94- value = DUMMY_MODELS [i ] if DUMMY_MODELS else "" ,
95- interactive = True ,
140+ model_selectors [i ] = gr .Markdown (
141+ value = ANONY_NAMES [i ],
96142 show_label = False ,
97- container = False ,
98- )
99-
100-
101- with gr .Group ():
102- with gr .Row ():
103- output_left = gr .Image (
104- type = "pil" ,
105- show_label = False
106- )
107- output_right = gr .Image (
108- type = "pil" ,
109- show_label = False
110143 )
111144
145+ with gr .Row ():
146+ left_btn = gr .Button (
147+ value = "Left" ,
148+ interactive = False ,
149+ visible = False
150+ )
151+ tie_btn = gr .Button (
152+ value = "Tie" ,
153+ interactive = False ,
154+ visible = False
155+ )
156+ right_btn = gr .Button (
157+ value = "Right" ,
158+ interactive = False ,
159+ visible = False
160+ )
161+ idk_btn = gr .Button (
162+ value = "IDK" ,
163+ interactive = False ,
164+ visible = False
165+ )
166+
112167 with gr .Row ():
113168 text_input = gr .Textbox (
114169 label = "Prompt" ,
115170 placeholder = "Enter your prompt here..." ,
116171 show_label = False
117172 )
118173 send_btn = gr .Button ("Generate" , variant = "primary" )
119-
174+
120175 # Handle generation
121- send_btn .click (
176+ gen_output = send_btn .click (
177+ fn = add_text ,
178+ inputs = states + [text_input ],
179+ outputs = states
180+ ).then (
122181 fn = generate_image_multi ,
123- inputs = model_selectors + [ text_input ] ,
182+ inputs = states ,
124183 outputs = [output_left , output_right ]
184+ ).then (
185+ fn = flash_buttons ,
186+ outputs = [left_btn , tie_btn , right_btn , idk_btn ]
125187 )
126188
127189 text_input .submit (
190+ fn = add_text ,
191+ inputs = states + [text_input ],
192+ outputs = states
193+ ).then (
128194 fn = generate_image_multi ,
129- inputs = model_selectors + [ text_input ] ,
195+ inputs = states ,
130196 outputs = [output_left , output_right ]
197+ ).then (
198+ fn = flash_buttons ,
199+ outputs = [left_btn , tie_btn , right_btn , idk_btn ]
131200 )
132201
202+ # Handle voting buttons
203+ for btn in [left_btn , tie_btn , right_btn , idk_btn ]:
204+ btn .click (
205+ fn = reveal_models ,
206+ inputs = states ,
207+ outputs = model_selectors
208+ )
209+
133210if __name__ == "__main__" :
134211 demo .launch (share = True )
0 commit comments