Skip to content

Commit bb57382

Browse files
committed
damn it i forgot to run lint
1 parent 61444f6 commit bb57382

File tree

2 files changed

+70
-51
lines changed

2 files changed

+70
-51
lines changed

fastchat/serve/monitor/classify/category.py

Lines changed: 49 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def get_score(self, judgment):
7979

8080
def pre_process(self, prompt):
8181
conv = [{"role": "system", "content": self.sys_prompt}]
82-
conv.append({"role": "user", "content": prompt['prompt']})
82+
conv.append({"role": "user", "content": prompt["prompt"]})
8383
return conv
8484

8585
def post_process(self, judgment):
@@ -106,7 +106,7 @@ def get_score(self, judgment):
106106
return None
107107

108108
def pre_process(self, prompt):
109-
args = {"PROMPT": prompt['prompt']}
109+
args = {"PROMPT": prompt["prompt"]}
110110
conv = [
111111
{"role": "system", "content": self.system_prompt},
112112
{"role": "user", "content": self.prompt_template.format(**args)},
@@ -140,7 +140,7 @@ def get_score(self, judgment):
140140
return None
141141

142142
def pre_process(self, prompt):
143-
args = {"PROMPT": prompt['prompt']}
143+
args = {"PROMPT": prompt["prompt"]}
144144
conv = [
145145
{"role": "system", "content": self.system_prompt},
146146
{"role": "user", "content": self.prompt_template.format(**args)},
@@ -177,7 +177,7 @@ def get_score(self, judgment):
177177
return None
178178

179179
def pre_process(self, prompt):
180-
args = {"PROMPT": prompt['prompt']}
180+
args = {"PROMPT": prompt["prompt"]}
181181
conv = [
182182
{"role": "system", "content": self.system_prompt},
183183
{"role": "user", "content": self.prompt_template.format(**args)},
@@ -188,20 +188,19 @@ def post_process(self, judgment):
188188
score = self.get_score(judgment=judgment)
189189
bool_score = bool(score == "yes") if score else False
190190
return {"creative_writing": bool_score, "score": score}
191-
191+
192+
192193
#####################
193194
# Vision Categories #
194195
#####################
195196
class CategoryCaptioning(Category):
196-
197197
def __init__(self):
198198
super().__init__()
199199
self.name_tag = "captioning_v0.1"
200200
self.pattern = re.compile(r"<decision>(\w+)<\/decision>")
201201
self.system_prompt = "You are tasked with determining if a given VQA question is a captioning question. A captioning question asks for a general, overall description of the entire image. It must be a single, open-ended query that does NOT ask about particular objects, people, or parts of the image, nor require interpretation beyond a broad description of what is visually present. Examples include 'What is happening in this image?', 'Describe this picture.', 'Explain', etc. An example of a non-captioning question is 'Describe what is funny in this picture.' because it asks for a specific interpretation of the image content. \n\nOutput your verdict in the following format:<decision>\n[yes/no]\n</decision>. Do NOT explain."
202202
self.prompt_template = "<user_prompt>\n{PROMPT}\n</user_prompt>"
203203

204-
205204
def get_score(self, judgment):
206205
matches = self.pattern.findall(judgment.replace("\n", "").lower())
207206
matches = [m for m in matches if m != ""]
@@ -212,7 +211,7 @@ def get_score(self, judgment):
212211
else:
213212
return None
214213

215-
def pre_process(self, prompt, api_type='openai'):
214+
def pre_process(self, prompt, api_type="openai"):
216215
args = {"PROMPT": prompt["prompt"]}
217216
conv = [
218217
{"role": "system", "content": self.system_prompt},
@@ -223,7 +222,8 @@ def pre_process(self, prompt, api_type='openai'):
223222
def post_process(self, judgment):
224223
score = self.get_score(judgment=judgment)
225224
return {"captioning": bool(score == "yes") if score else False}
226-
225+
226+
227227
class CategoryCreativeWritingVision(Category):
228228
def __init__(self):
229229
super().__init__()
@@ -248,8 +248,8 @@ def get_score(self, judgment):
248248
else:
249249
return None
250250

251-
def pre_process(self, prompt, api_type='openai'):
252-
args = {"PROMPT": prompt['prompt']}
251+
def pre_process(self, prompt, api_type="openai"):
252+
args = {"PROMPT": prompt["prompt"]}
253253
conv = [
254254
{"role": "system", "content": self.system_prompt},
255255
{"role": "user", "content": self.prompt_template.format(**args)},
@@ -260,16 +260,16 @@ def post_process(self, judgment):
260260
score = self.get_score(judgment=judgment)
261261
bool_score = bool(score == "yes") if score else False
262262
return {"creative_writing": bool_score, "score": score}
263-
264-
class CategoryEntityRecognition(Category):
265263

264+
265+
class CategoryEntityRecognition(Category):
266266
def __init__(self):
267267
super().__init__()
268268
self.name_tag = "entity_recognition_v0.1"
269269
self.pattern = re.compile(r"<decision>(\w+)<\/decision>")
270270
self.system_prompt = "You are tasked with determining if a given VQA question is an entity recognition question. An entity recognition question asks for the identification of specific objects or people in the image. This does NOT include questions that ask for a general description of the image, questions that only ask for object counts, or questions that only require reading text in the image.\n\nOutput your verdict in the following format:<decision>\n[yes/no]\n</decision>. Do NOT explain."
271271
self.prompt_template = "<user_prompt>\n{PROMPT}\n</user_prompt>"
272-
272+
273273
def get_score(self, judgment):
274274
matches = self.pattern.findall(judgment.replace("\n", "").lower())
275275
matches = [m for m in matches if m != ""]
@@ -280,7 +280,7 @@ def get_score(self, judgment):
280280
else:
281281
return None
282282

283-
def pre_process(self, prompt, api_type='openai'):
283+
def pre_process(self, prompt, api_type="openai"):
284284
args = {"PROMPT": prompt["prompt"]}
285285
conv = [
286286
{"role": "system", "content": self.system_prompt},
@@ -291,19 +291,22 @@ def pre_process(self, prompt, api_type='openai'):
291291
def post_process(self, judgment):
292292
score = self.get_score(judgment=judgment)
293293
return {"entity_recognition": bool(score == "yes") if score else False}
294-
294+
295+
295296
import base64
296297
import io
297298
from PIL import Image
299+
300+
298301
def pil_to_base64(image_path):
299302
image = Image.open(image_path)
300303
buffered = io.BytesIO()
301304
image.save(buffered, format="PNG")
302305
img_str = base64.b64encode(buffered.getvalue()).decode()
303306
return img_str
304307

305-
class CategoryOpticalCharacterRecognition(Category):
306308

309+
class CategoryOpticalCharacterRecognition(Category):
307310
def __init__(self):
308311
super().__init__()
309312
self.name_tag = "ocr_v0.1"
@@ -321,21 +324,23 @@ def get_score(self, judgment):
321324
else:
322325
return None
323326

324-
def pre_process(self, prompt, api_type='openai'):
327+
def pre_process(self, prompt, api_type="openai"):
325328
args = {"PROMPT": prompt["prompt"]}
326329
base64_image = pil_to_base64(prompt["image_path"])
327-
if api_type == 'anthropic':
330+
if api_type == "anthropic":
328331
conv = [
329332
{"role": "system", "content": self.system_prompt},
330333
{
331334
"role": "user",
332335
"content": [
333-
{
336+
{
334337
"type": "image",
335338
"source": {
336339
"type": "base64",
337340
"media_type": "image/jpeg",
338-
"data": base64.b64encode(prompt["image_path"].content).decode("utf-8"),
341+
"data": base64.b64encode(
342+
prompt["image_path"].content
343+
).decode("utf-8"),
339344
},
340345
},
341346
{"type": "text", "text": self.prompt_template.format(**args)},
@@ -363,9 +368,9 @@ def pre_process(self, prompt, api_type='openai'):
363368
def post_process(self, judgment):
364369
score = self.get_score(judgment=judgment)
365370
return {"ocr": bool(score == "yes") if score else False}
366-
367-
class CategoryHumor(Category):
368371

372+
373+
class CategoryHumor(Category):
369374
def __init__(self):
370375
super().__init__()
371376
self.name_tag = "humor_v0.1"
@@ -382,17 +387,17 @@ def get_score(self, judgment):
382387
return matches[0]
383388
else:
384389
return None
385-
386-
def pre_process(self, prompt, api_type='openai'):
390+
391+
def pre_process(self, prompt, api_type="openai"):
387392
args = {"PROMPT": prompt["prompt"]}
388393
base64_image = pil_to_base64(prompt["image_path"])
389-
if api_type == 'anthropic':
394+
if api_type == "anthropic":
390395
conv = [
391396
{"role": "system", "content": self.system_prompt},
392397
{
393398
"role": "user",
394399
"content": [
395-
{
400+
{
396401
"type": "image",
397402
"source": {
398403
"type": "base64",
@@ -421,14 +426,16 @@ def pre_process(self, prompt, api_type='openai'):
421426
},
422427
]
423428
return conv
424-
429+
425430
def post_process(self, judgment):
426431
score = self.get_score(judgment=judgment)
427432
return {"humor": bool(score == "yes") if score else False}
428-
433+
434+
429435
import os
430-
class CategoryHomework(Category):
431436

437+
438+
class CategoryHomework(Category):
432439
def __init__(self):
433440
super().__init__()
434441
self.name_tag = "homework_v0.1"
@@ -449,21 +456,21 @@ def get_score(self, judgment):
449456
return matches[0]
450457
else:
451458
return None
452-
453-
def pre_process(self, prompt, api_type='openai'):
459+
460+
def pre_process(self, prompt, api_type="openai"):
454461
base64_image = pil_to_base64(prompt["image_path"])
455462

456463
# Open the local image file in binary mode and encode it as base64
457464
assert os.path.exists(prompt["image_path"])
458465
with open(prompt["image_path"], "rb") as image_file:
459466
image_data = base64.b64encode(image_file.read()).decode("utf-8")
460-
if api_type == 'anthropic':
467+
if api_type == "anthropic":
461468
conv = [
462469
{"role": "system", "content": self.system_prompt},
463470
{
464471
"role": "user",
465472
"content": [
466-
{
473+
{
467474
"type": "image",
468475
"source": {
469476
"type": "base64",
@@ -492,13 +499,13 @@ def pre_process(self, prompt, api_type='openai'):
492499
},
493500
]
494501
return conv
495-
502+
496503
def post_process(self, judgment):
497504
score = self.get_score(judgment=judgment)
498505
return {"homework": bool(score == "yes") if score else False}
499-
500-
class CategoryDiagram(Category):
501506

507+
508+
class CategoryDiagram(Category):
502509
def __init__(self):
503510
super().__init__()
504511
self.name_tag = "diagram_v0.1"
@@ -523,21 +530,21 @@ def get_score(self, judgment):
523530
return matches[0]
524531
else:
525532
return None
526-
527-
def pre_process(self, prompt, api_type='openai'):
533+
534+
def pre_process(self, prompt, api_type="openai"):
528535
base64_image = pil_to_base64(prompt["image_path"])
529536

530537
# Open the local image file in binary mode and encode it as base64
531538
assert os.path.exists(prompt["image_path"])
532539
with open(prompt["image_path"], "rb") as image_file:
533540
image_data = base64.b64encode(image_file.read()).decode("utf-8")
534-
if api_type == 'anthropic':
541+
if api_type == "anthropic":
535542
conv = [
536543
{"role": "system", "content": self.system_prompt},
537544
{
538545
"role": "user",
539546
"content": [
540-
{
547+
{
541548
"type": "image",
542549
"source": {
543550
"type": "base64",
@@ -566,7 +573,7 @@ def pre_process(self, prompt, api_type='openai'):
566573
},
567574
]
568575
return conv
569-
576+
570577
def post_process(self, judgment):
571578
score = self.get_score(judgment=judgment)
572579
return {"diagram": bool(score == "yes") if score else False}

fastchat/serve/monitor/classify/label.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def chat_completion_openai(model, messages, temperature, max_tokens, api_dict=No
8888

8989
return output
9090

91+
9192
def chat_completion_anthropic(model, messages, temperature, max_tokens, api_dict=None):
9293
import anthropic
9394

@@ -111,7 +112,7 @@ def chat_completion_anthropic(model, messages, temperature, max_tokens, api_dict
111112
stop_sequences=[anthropic.HUMAN_PROMPT],
112113
max_tokens=max_tokens,
113114
temperature=temperature,
114-
system=sys_msg
115+
system=sys_msg,
115116
)
116117
output = response.content[0].text
117118
break
@@ -120,7 +121,10 @@ def chat_completion_anthropic(model, messages, temperature, max_tokens, api_dict
120121
time.sleep(API_RETRY_SLEEP)
121122
return output
122123

123-
def chat_completion_gemini(model, messages, temperature, max_tokens, api_dict=None, image_path=None):
124+
125+
def chat_completion_gemini(
126+
model, messages, temperature, max_tokens, api_dict=None, image_path=None
127+
):
124128
import google
125129
import google.generativeai as genai
126130
from google.generativeai.types import HarmCategory, HarmBlockThreshold
@@ -139,9 +143,9 @@ def chat_completion_gemini(model, messages, temperature, max_tokens, api_dict=No
139143

140144
prompt = messages[0]["content"]
141145
if type(prompt) == list:
142-
prompt = [prompt[0]['text'], Image.open(image_path).convert('RGB')]
146+
prompt = [prompt[0]["text"], Image.open(image_path).convert("RGB")]
143147

144-
safety_settings={
148+
safety_settings = {
145149
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
146150
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
147151
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
@@ -156,7 +160,9 @@ def chat_completion_gemini(model, messages, temperature, max_tokens, api_dict=No
156160
gemini.temperature = temperature
157161
response = gemini.generate_content(prompt, safety_settings=safety_settings)
158162
if response.candidates[0].finish_reason != 1:
159-
print(f"Gemini did not finish generating content: {response.candidates[0].finish_reason}")
163+
print(
164+
f"Gemini did not finish generating content: {response.candidates[0].finish_reason}"
165+
)
160166
output = "Gemini did not finish generating content"
161167
else:
162168
output = response.text
@@ -215,7 +221,7 @@ def get_answer(
215221
temperature=temperature,
216222
max_tokens=max_tokens,
217223
api_dict=api_dict,
218-
image_path=question.get("image_path")
224+
image_path=question.get("image_path"),
219225
)
220226
else:
221227
raise ValueError(f"api_type {api_type} not supported")
@@ -309,7 +315,9 @@ def find_required_tasks(row):
309315
input_data["image_hash"] = input_data.conversation_a.map(
310316
lambda convo: convo[0]["content"][1][0]
311317
)
312-
input_data["image_path"] = input_data.image_hash.map(lambda x: f"{config['image_dir']}/{x}.png")
318+
input_data["image_path"] = input_data.image_hash.map(
319+
lambda x: f"{config['image_dir']}/{x}.png"
320+
)
313321

314322
if config["cache_file"]:
315323
print("loading cache data")
@@ -360,11 +368,15 @@ def find_required_tasks(row):
360368

361369
if args.vision:
362370
not_labeled["prompt"] = not_labeled.conversation_a.map(
363-
lambda convo: "\n".join([convo[i]["content"][0] for i in range(0, len(convo), 2)])
371+
lambda convo: "\n".join(
372+
[convo[i]["content"][0] for i in range(0, len(convo), 2)]
373+
)
364374
)
365375
else:
366376
not_labeled["prompt"] = not_labeled.conversation_a.map(
367-
lambda convo: "\n".join([convo[i]["content"] for i in range(0, len(convo), 2)])
377+
lambda convo: "\n".join(
378+
[convo[i]["content"] for i in range(0, len(convo), 2)]
379+
)
368380
)
369381
not_labeled["prompt"] = not_labeled.prompt.map(lambda x: x[:12500])
370382

0 commit comments

Comments
 (0)