@@ -89,6 +89,95 @@ def chat_completion_openai(model, messages, temperature, max_tokens, api_dict=None):
     return output


+def chat_completion_anthropic(model, messages, temperature, max_tokens, api_dict=None):
+    import anthropic
+
+    if api_dict:
+        api_key = api_dict["api_key"]
+    else:
+        api_key = os.environ["ANTHROPIC_API_KEY"]
+
+    # Anthropic takes the system prompt as a separate argument, not as a message.
+    sys_msg = ""
+    if messages[0]["role"] == "system":
+        sys_msg = messages[0]["content"]
+        messages = messages[1:]
+
+    output = API_ERROR_OUTPUT
+    for _ in range(API_MAX_RETRY):
+        try:
+            c = anthropic.Anthropic(api_key=api_key)
+            response = c.messages.create(
+                model=model,
+                messages=messages,
+                # Stop at the legacy "\n\nHuman:" turn marker.
+                stop_sequences=[anthropic.HUMAN_PROMPT],
+                max_tokens=max_tokens,
+                temperature=temperature,
+                system=sys_msg,
+            )
+            output = response.content[0].text
+            break
+        except anthropic.APIError as e:
+            print(type(e), e)
+            time.sleep(API_RETRY_SLEEP)
+    return output
+
+
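# Illustrative only (not part of this diff): a minimal sketch of calling the
# new helper; the model name and key below are placeholders, not values from
# this PR.
#
#   messages = [
#       {"role": "system", "content": "You are a concise classifier."},
#       {"role": "user", "content": "Label this prompt."},
#   ]
#   output = chat_completion_anthropic(
#       model="claude-3-opus-20240229",
#       messages=messages,
#       temperature=0.0,
#       max_tokens=256,
#       api_dict={"api_key": "sk-ant-..."},
#   )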
+def chat_completion_gemini(
+    model, messages, temperature, max_tokens, api_dict=None, image_path=None
+):
+    import google.api_core.exceptions  # needed for the ResourceExhausted handler
+    import google.generativeai as genai
+    from google.generativeai.types import HarmCategory, HarmBlockThreshold
+    from PIL import Image
+
+    if api_dict:
+        api_key = api_dict["api_key"]
+        genai.configure(api_key=api_key)
+    else:
+        genai.configure(api_key=os.environ["GENAI_API_KEY"])
+
+    sys_msg = ""
+    if messages[0]["role"] == "system":
+        sys_msg = messages[0]["content"]
+        messages = messages[1:]
+
+    # A list-typed prompt signals a vision request: send the text plus the image.
+    prompt = messages[0]["content"]
+    if isinstance(prompt, list):
+        prompt = [prompt[0]["text"], Image.open(image_path).convert("RGB")]
+
+    safety_settings = {
+        HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
+        HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
+        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
+        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
+    }
+    output = API_ERROR_OUTPUT
+    for _ in range(API_MAX_RETRY):
+        try:
+            # Avoid passing an empty system instruction.
+            gemini = genai.GenerativeModel(model, system_instruction=sys_msg or None)
+            response = gemini.generate_content(
+                prompt,
+                # Generation settings must go through generation_config;
+                # assigning attributes on the model object has no effect.
+                generation_config=genai.GenerationConfig(
+                    max_output_tokens=max_tokens,
+                    temperature=temperature,
+                ),
+                safety_settings=safety_settings,
+            )
+            # finish_reason 1 == STOP; anything else means generation was cut short.
+            if response.candidates[0].finish_reason != 1:
+                print(
+                    f"Gemini did not finish generating content: {response.candidates[0].finish_reason}"
+                )
+                output = "Gemini did not finish generating content"
+            else:
+                output = response.text
+            break
+        except google.api_core.exceptions.ResourceExhausted as e:
+            # THIS IS A TEMPORARY FIX
+            print(type(e), e)
+            time.sleep(API_RETRY_SLEEP)
+        except Exception as e:
+            # THIS IS A TEMPORARY FIX
+            print(type(e), e)
+            time.sleep(API_RETRY_SLEEP)
+    return output
+
+
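# Illustrative only (not part of this diff): the vision path expects the first
# message's content to be a list of parts whose first element carries the text;
# the model name and file name below are placeholders.
#
#   messages = [{"role": "user", "content": [{"text": "Describe the image."}]}]
#   output = chat_completion_gemini(
#       model="gemini-1.5-pro",
#       messages=messages,
#       temperature=0.0,
#       max_tokens=256,
#       image_path="images/<image_hash>.png",
#   )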
 def get_answer(
     question: dict,
     model_name: str,
@@ -98,6 +187,7 @@ def get_answer(
     api_dict: dict,
     categories: list,
     testing: bool,
+    api_type: str,
 ):
     if "category_tag" in question:
         category_tag = question["category_tag"]
@@ -107,14 +197,34 @@ def get_answer(
     output_log = {}

     for category in categories:
-        conv = category.pre_process(question["prompt"])
-        output = chat_completion_openai(
-            model=model_name,
-            messages=conv,
-            temperature=temperature,
-            max_tokens=max_tokens,
-            api_dict=api_dict,
-        )
+        conv = category.pre_process(question)
+        if api_type == "openai":
+            output = chat_completion_openai(
+                model=model_name,
+                messages=conv,
+                temperature=temperature,
+                max_tokens=max_tokens,
+                api_dict=api_dict,
+            )
+        elif api_type == "anthropic":
+            output = chat_completion_anthropic(
+                model=model_name,
+                messages=conv,
+                temperature=temperature,
+                max_tokens=max_tokens,
+                api_dict=api_dict,
+            )
+        elif api_type == "gemini":
+            output = chat_completion_gemini(
+                model=model_name,
+                messages=conv,
+                temperature=temperature,
+                max_tokens=max_tokens,
+                api_dict=api_dict,
+                image_path=question.get("image_path"),
+            )
+        else:
+            raise ValueError(f"api_type {api_type} not supported")
         # Dump answers
         category_tag[category.name_tag] = category.post_process(output)

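# Illustrative only (not part of this diff): the dispatch above is driven by a
# new config key. A hedged sketch of the relevant entries, shown as the parsed
# Python dict; "image_dir" and "cache_file" appear elsewhere in this script,
# and the values are placeholders.
#
#   config = {
#       "api_type": "anthropic",  # one of: "openai", "anthropic", "gemini"
#       "image_dir": "images",    # consumed when --vision is set
#       "cache_file": None,
#   }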
@@ -169,6 +279,7 @@ def find_required_tasks(row):
     parser = argparse.ArgumentParser()
     parser.add_argument("--config", type=str, required=True)
     parser.add_argument("--testing", action="store_true")
+    parser.add_argument("--vision", action="store_true")
     args = parser.parse_args()

     enter = input(
@@ -199,6 +310,15 @@ def find_required_tasks(row):
     assert len(input_data) == len(input_data.uid.unique())
     print(f"{len(input_data)} # of input data just loaded")

+    if args.vision:
+        old_len = len(input_data)
+        input_data["image_hash"] = input_data.conversation_a.map(
+            lambda convo: convo[0]["content"][1][0]
+        )
+        input_data["image_path"] = input_data.image_hash.map(
+            lambda x: f"{config['image_dir']}/{x}.png"
+        )
+
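# Illustrative only (not part of this diff): the indexing above implies each
# vision conversation turn stores its content as (text, [image_hash, ...]).
# The structure is inferred from convo[0]["content"][1][0]; the hash and text
# values below are placeholders.
#
#   convo = [
#       {"role": "user", "content": ("What is in this image?", ["<image_hash>"])},
#       {"role": "assistant", "content": "A cat on a sofa."},
#   ]
#   convo[0]["content"][1][0]  # -> "<image_hash>"
#   # image_path -> f"{config['image_dir']}/<image_hash>.png"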
     if config["cache_file"]:
         print("loading cache data")
         with open(config["cache_file"], "rb") as f:
@@ -246,9 +366,18 @@ def find_required_tasks(row):
             f"{name}: {len(not_labeled[not_labeled.required_tasks.map(lambda tasks: name in tasks)])}"
         )

-    not_labeled["prompt"] = not_labeled.conversation_a.map(
-        lambda convo: "\n".join([convo[i]["content"] for i in range(0, len(convo), 2)])
-    )
+    if args.vision:
+        not_labeled["prompt"] = not_labeled.conversation_a.map(
+            lambda convo: "\n".join(
+                [convo[i]["content"][0] for i in range(0, len(convo), 2)]
+            )
+        )
+    else:
+        not_labeled["prompt"] = not_labeled.conversation_a.map(
+            lambda convo: "\n".join(
+                [convo[i]["content"] for i in range(0, len(convo), 2)]
+            )
+        )
     not_labeled["prompt"] = not_labeled.prompt.map(lambda x: x[:12500])
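# Illustrative only (not part of this diff): a worked example of the prompt
# flattening above; even indices in conversation_a are user turns.
#
#   convo = [
#       {"role": "user", "content": "Hi"},
#       {"role": "assistant", "content": "Hello!"},
#       {"role": "user", "content": "Write a haiku."},
#   ]
#   "\n".join(convo[i]["content"] for i in range(0, len(convo), 2))
#   # -> "Hi\nWrite a haiku."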

     with concurrent.futures.ThreadPoolExecutor(
@@ -270,6 +399,7 @@ def find_required_tasks(row):
                     if category.name_tag in row["required_tasks"]
                 ],
                 args.testing,
+                config["api_type"],
             )
             futures.append(future)
         for future in tqdm.tqdm(