import requests
import base64
import io
+ import hashlib
from PIL import Image
+ import datetime
+ import json
+
+ from fastchat.constants import LOGDIR
+ from fastchat.utils import upload_image_file_to_gcs

FIREWORKS_API_KEY = os.getenv("FIREWORKS_API_KEY")
- API_BASE = "https://api.fireworks.ai/inference/v1/workflows/accounts/fireworks/models/{}/text_to_image"
+ API_BASE = "https://api.fireworks.ai/inference/v1/workflows/accounts/fireworks/models/{model}/text_to_image"
DUMMY_MODELS = ["stable-diffusion-3p5-medium",
                "stable-diffusion-3p5-large",
                "stable-diffusion-3p5-large-turbo",
                "flux-1-dev-fp8",
                "flux-1-schnell-fp8",
                ]

- def generate_image(prompt, model):
+ def get_conv_log_filename():
+     t = datetime.datetime.now()
+     conv_log_filename = f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json"
+     return os.path.join(LOGDIR, f"txt2img-{conv_log_filename}")
+
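+ # Generate one image with the selected model, upload it to GCS, and log the prompt and filename.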
+ def generate_image(model, prompt):
    """Generate image from text prompt using Fireworks API"""
    headers = {
        "Authorization": f"Bearer {FIREWORKS_API_KEY}",
@@ -38,29 +49,65 @@ def generate_image(prompt, model):

        image_bytes = response.content
        image = Image.open(io.BytesIO(image_bytes))
-
-         return image
-
+
    except requests.exceptions.RequestException as e:
        return f"Error generating image: {str(e)}"

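+     # Name the uploaded file after the image's MD5 hash so identical outputs map to the same object.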
+     log_filename = get_conv_log_filename()
+     image_hash = hashlib.md5(image.tobytes()).hexdigest()
+     image_filename = f"{image_hash}.png"
+     upload_image_file_to_gcs(image, image_filename)
+
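+     # Append one JSON record per generation to the daily text-to-image log.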
+     with open(log_filename, "a") as f:
+         data = {
+             "model": model,
+             "prompt": prompt,
+             "image_filename": image_filename
+         }
+         f.write(json.dumps(data) + "\n")
+
+     return image
+
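+ # Run the same prompt through both selected models and return [left_image, right_image].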
+ def generate_image_multi(model_left, model_right, prompt):
+     images = []
+     for model in [model_left, model_right]:
+         images.append(generate_image(model, prompt))
+
+     return images
+
+
# Create Gradio interface
with gr.Blocks(title="Text to Image Generator") as demo:
    gr.Markdown("# Text to Image Generator")
    gr.Markdown("Enter a text prompt to generate an image")

+
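+     # Number of models compared side by side.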
+     num_sides = 2
+     model_selectors = [None] * num_sides
+
    with gr.Column():
        with gr.Group():
-             model_selector = gr.Dropdown(
-                 choices=DUMMY_MODELS,
-                 interactive=True,
-                 show_label=False,
-                 container=False,
-             )
-             image_output = gr.Image(
-                 label="Generated Image",
-                 type="pil"
-             )
+             with gr.Row():
+                 for i in range(num_sides):
+                     model_selectors[i] = gr.Dropdown(
+                         choices=DUMMY_MODELS,
+                         value=DUMMY_MODELS[i] if DUMMY_MODELS else "",
+                         interactive=True,
+                         show_label=False,
+                         container=False,
+                     )
+
+
+         with gr.Group():
+             with gr.Row():
+                 output_left = gr.Image(
+                     type="pil",
+                     show_label=False
+                 )
+                 output_right = gr.Image(
+                     type="pil",
+                     show_label=False
+                 )

    with gr.Row():
        text_input = gr.Textbox(
@@ -72,15 +119,15 @@ def generate_image(prompt, model):

    # Handle generation
    send_btn.click(
-         fn=generate_image,
-         inputs=[text_input, model_selector],
-         outputs=image_output
+         fn=generate_image_multi,
+         inputs=model_selectors + [text_input],
+         outputs=[output_left, output_right]
    )

    text_input.submit(
-         fn=generate_image,
-         inputs=[text_input, model_selector],
-         outputs=image_output
+         fn=generate_image_multi,
+         inputs=model_selectors + [text_input],
+         outputs=[output_left, output_right]
    )

if __name__ == "__main__":