
Commit bb882c3

Vision Categories! (#3639)
1 parent 33bd3d9 commit bb882c3

File tree

8 files changed: +644, -233 lines


fastchat/serve/monitor/classify/README.md

Lines changed: 2 additions & 0 deletions
@@ -29,6 +29,8 @@ To test your new classifier for a new category, you would have to make sure you
 python label.py --config config.yaml --testing
 ```
 
+If you are labeling a vision category, add the `--vision` flag to the command. This will add a new column to the input data called `image_path` that contains the path to the image corresponding to each conversation. Ensure that you update your config with the correct `image_dir` where the images are stored.
+
 Then, add your new category bench to `tag_names` in `display_score.py`. After making sure that you also have a correctly formatted ground truth json file, you can report the performance of your classifier by running
 ```console
 python display_score.py --bench <your_bench>
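
As background for the `--vision` flag described above: the `label.py` changes in this commit look up each conversation's image at `{image_dir}/{image_hash}.png`, where the hash comes from the first user turn of `conversation_a`. The sketch below is a hypothetical pre-flight check, not part of the commit, for confirming those images exist before labeling; the input-record structure is inferred from the `label.py` diff further down.

```python
# Hypothetical helper (not part of this commit): verify that every conversation
# in a vision input file has its image available under image_dir.
# Assumes the input json is a list of records shaped like the ones label.py
# consumes, where convo[0]["content"][1][0] is the image hash.
import json
from pathlib import Path


def missing_images(input_file: str, image_dir: str) -> list:
    with open(input_file) as f:
        rows = json.load(f)
    missing = []
    for row in rows:
        image_hash = row["conversation_a"][0]["content"][1][0]
        path = Path(image_dir) / f"{image_hash}.png"
        if not path.exists():
            missing.append(str(path))
    return missing


# Example (paths are placeholders):
# print(missing_images("vision_battles.json", "/data/vision_arena_images"))
```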

fastchat/serve/monitor/classify/category.py

Lines changed: 407 additions & 4 deletions
Large diffs are not rendered by default.

fastchat/serve/monitor/classify/config.yaml

Lines changed: 3 additions & 0 deletions
@@ -14,13 +14,16 @@ task_name:
 
 model_name: null
 name: llama-3-70b-instruct
+api_type: openai
 endpoints:
   - api_base: null
     api_key: null
 parallel: 50
 temperature: 0.0
 max_token: 512
 
+image_dir: null # directory where vision arena images are stored
+
 max_retry: 2
 retry_sleep: 10
 error_output: $ERROR$
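
The two added fields are consumed by `label.py` below: `api_type` selects which chat-completion wrapper `get_answer` calls, and `image_dir` is where the `--vision` path looks up images. A minimal validation sketch, not part of the commit, assuming PyYAML for loading; the field names match this config and the accepted `api_type` values mirror the dispatch added to `get_answer`:

```python
# Illustrative config check (hypothetical helper, not in the commit).
import yaml


def load_and_check_config(path: str, vision: bool = False) -> dict:
    with open(path) as f:
        config = yaml.safe_load(f)
    # Only these api_type values are handled by get_answer in label.py.
    if config.get("api_type") not in ("openai", "anthropic", "gemini"):
        raise ValueError(f"api_type {config.get('api_type')} not supported")
    # The --vision path builds image paths as f"{image_dir}/{image_hash}.png".
    if vision and not config.get("image_dir"):
        raise ValueError("--vision requires image_dir to be set in the config")
    return config
```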

fastchat/serve/monitor/classify/label.py

Lines changed: 141 additions & 11 deletions
@@ -89,6 +89,95 @@ def chat_completion_openai(model, messages, temperature, max_tokens, api_dict=None):
     return output
 
 
+def chat_completion_anthropic(model, messages, temperature, max_tokens, api_dict=None):
+    import anthropic
+
+    if api_dict:
+        api_key = api_dict["api_key"]
+    else:
+        api_key = os.environ["ANTHROPIC_API_KEY"]
+
+    sys_msg = ""
+    if messages[0]["role"] == "system":
+        sys_msg = messages[0]["content"]
+        messages = messages[1:]
+
+    output = API_ERROR_OUTPUT
+    for _ in range(API_MAX_RETRY):
+        try:
+            c = anthropic.Anthropic(api_key=api_key)
+            response = c.messages.create(
+                model=model,
+                messages=messages,
+                stop_sequences=[anthropic.HUMAN_PROMPT],
+                max_tokens=max_tokens,
+                temperature=temperature,
+                system=sys_msg,
+            )
+            output = response.content[0].text
+            break
+        except anthropic.APIError as e:
+            print(type(e), e)
+            time.sleep(API_RETRY_SLEEP)
+    return output
+
+
+def chat_completion_gemini(
+    model, messages, temperature, max_tokens, api_dict=None, image_path=None
+):
+    import google
+    import google.generativeai as genai
+    from google.generativeai.types import HarmCategory, HarmBlockThreshold
+    from PIL import Image
+
+    if api_dict:
+        api_key = api_dict["api_key"]
+        genai.configure(api_key=api_key)
+    else:
+        genai.configure(api_key=os.environ["GENAI_API_KEY"])
+
+    sys_msg = ""
+    if messages[0]["role"] == "system":
+        sys_msg = messages[0]["content"]
+        messages = messages[1:]
+
+    prompt = messages[0]["content"]
+    if type(prompt) == list:
+        prompt = [prompt[0]["text"], Image.open(image_path).convert("RGB")]
+
+    safety_settings = {
+        HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
+        HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
+        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
+        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
+    }
+    output = API_ERROR_OUTPUT
+    for _ in range(API_MAX_RETRY):
+        try:
+            gemini = genai.GenerativeModel(model, system_instruction=sys_msg)
+            gemini.max_output_tokens = max_tokens
+            gemini.temperature = temperature
+            response = gemini.generate_content(prompt, safety_settings=safety_settings)
+            if response.candidates[0].finish_reason != 1:
+                print(
+                    f"Gemini did not finish generating content: {response.candidates[0].finish_reason}"
+                )
+                output = "Gemini did not finish generating content"
+            else:
+                output = response.text
+            break
+        except google.api_core.exceptions.ResourceExhausted as e:
+            # THIS IS A TEMPORARY FIX
+            print(type(e), e)
+            time.sleep(API_RETRY_SLEEP)
+        except Exception as e:
+            # THIS IS A TEMPORARY FIX
+            print(type(e), e)
+            time.sleep(API_RETRY_SLEEP)
+    return output
+
+
 def get_answer(
     question: dict,
     model_name: str,
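
For reference, a sketch of the message shapes the two new wrappers appear to expect, inferred from the code above: an optional leading system message followed by user turns, and, on the Gemini vision path, user content given as a list whose first element carries the text while the image is supplied separately via `image_path`. The concrete values and the `"type"` key are illustrative assumptions, not taken from the commit.

```python
# Illustrative message shapes (assumptions inferred from the wrappers above).
text_messages = [
    {"role": "system", "content": "You are a category classifier."},
    {"role": "user", "content": "Classify the following prompt: ..."},
]

# Vision path: the user content is a list whose first element is a dict with a
# "text" field; chat_completion_gemini pairs that text with the image it loads
# from image_path.
vision_messages = [
    {"role": "system", "content": "You are a category classifier."},
    {"role": "user", "content": [{"type": "text", "text": "Classify the prompt: ..."}]},
]

# Hypothetical call, mirroring how get_answer invokes the wrapper (requires a
# GENAI_API_KEY and a real image file):
# output = chat_completion_gemini(
#     "gemini-1.5-flash", vision_messages, temperature=0.0, max_tokens=512,
#     image_path="/path/to/images/<image_hash>.png",
# )
```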
@@ -98,6 +187,7 @@ def get_answer(
     api_dict: dict,
     categories: list,
     testing: bool,
+    api_type: str,
 ):
     if "category_tag" in question:
         category_tag = question["category_tag"]
@@ -107,14 +197,34 @@
     output_log = {}
 
     for category in categories:
-        conv = category.pre_process(question["prompt"])
-        output = chat_completion_openai(
-            model=model_name,
-            messages=conv,
-            temperature=temperature,
-            max_tokens=max_tokens,
-            api_dict=api_dict,
-        )
+        conv = category.pre_process(question)
+        if api_type == "openai":
+            output = chat_completion_openai(
+                model=model_name,
+                messages=conv,
+                temperature=temperature,
+                max_tokens=max_tokens,
+                api_dict=api_dict,
+            )
+        elif api_type == "anthropic":
+            output = chat_completion_anthropic(
+                model=model_name,
+                messages=conv,
+                temperature=temperature,
+                max_tokens=max_tokens,
+                api_dict=api_dict,
+            )
+        elif api_type == "gemini":
+            output = chat_completion_gemini(
+                model=model_name,
+                messages=conv,
+                temperature=temperature,
+                max_tokens=max_tokens,
+                api_dict=api_dict,
+                image_path=question.get("image_path"),
+            )
+        else:
+            raise ValueError(f"api_type {api_type} not supported")
         # Dump answers
         category_tag[category.name_tag] = category.post_process(output)
 
@@ -169,6 +279,7 @@ def find_required_tasks(row):
     parser = argparse.ArgumentParser()
     parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--testing", action="store_true")
+    parser.add_argument("--vision", action="store_true")
     args = parser.parse_args()
 
     enter = input(
@@ -199,6 +310,15 @@ def find_required_tasks(row):
     assert len(input_data) == len(input_data.uid.unique())
     print(f"{len(input_data)}# of input data just loaded")
 
+    if args.vision:
+        old_len = len(input_data)
+        input_data["image_hash"] = input_data.conversation_a.map(
+            lambda convo: convo[0]["content"][1][0]
+        )
+        input_data["image_path"] = input_data.image_hash.map(
+            lambda x: f"{config['image_dir']}/{x}.png"
+        )
+
     if config["cache_file"]:
         print("loading cache data")
         with open(config["cache_file"], "rb") as f:
@@ -246,9 +366,18 @@ def find_required_tasks(row):
             f"{name}: {len(not_labeled[not_labeled.required_tasks.map(lambda tasks: name in tasks)])}"
         )
 
-    not_labeled["prompt"] = not_labeled.conversation_a.map(
-        lambda convo: "\n".join([convo[i]["content"] for i in range(0, len(convo), 2)])
-    )
+    if args.vision:
+        not_labeled["prompt"] = not_labeled.conversation_a.map(
+            lambda convo: "\n".join(
+                [convo[i]["content"][0] for i in range(0, len(convo), 2)]
+            )
+        )
+    else:
+        not_labeled["prompt"] = not_labeled.conversation_a.map(
+            lambda convo: "\n".join(
+                [convo[i]["content"] for i in range(0, len(convo), 2)]
+            )
+        )
     not_labeled["prompt"] = not_labeled.prompt.map(lambda x: x[:12500])
 
     with concurrent.futures.ThreadPoolExecutor(
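
The branch above exists because the two data formats store user turns differently: text battles keep each user turn as a plain string, while vision battles keep a `[text, image_hashes]` pair, so only element 0 is joined into the prompt. A small illustration with made-up rows; the exact vision layout is an assumption based on this hunk and the `image_hash` extraction earlier in the diff.

```python
# Hypothetical conversation rows illustrating the two formats handled above.
text_convo = [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "Paris."},
]
vision_convo = [
    {"role": "user", "content": ["What is shown in this image?", ["<image_hash>"]]},
    {"role": "assistant", "content": "A bar chart."},
]

# Text battles: user turns are plain strings.
prompt_text = "\n".join(
    text_convo[i]["content"] for i in range(0, len(text_convo), 2)
)
# Vision battles: user content is a [text, image_hashes] pair, so take element 0.
prompt_vision = "\n".join(
    vision_convo[i]["content"][0] for i in range(0, len(vision_convo), 2)
)
```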
@@ -270,6 +399,7 @@ def find_required_tasks(row):
                     if category.name_tag in row["required_tasks"]
                 ],
                 args.testing,
+                config["api_type"],
             )
             futures.append(future)
         for future in tqdm.tqdm(
Lines changed: 34 additions & 0 deletions (new file)
@@ -0,0 +1,34 @@
+# Yaml config file for category classification
+
+input_file: null # json
+cache_file: null # json
+output_file: null # json line
+
+convert_to_json: True
+
+task_name:
+  - captioning_v0.1
+  - homework_v0.1
+  - ocr_v0.1
+  - humor_v0.1
+  - entity_recognition_v0.1
+  - creative_writing_vision_v0.1
+  - diagram_v0.1
+
+
+model_name: null
+name: gemini-1.5-flash
+api_type: gemini
+endpoints:
+  - api_base: null
+    api_key: null
+
+parallel: 50
+temperature: 0.0
+max_token: 512
+
+image_dir: null # directory where vision arena images are stored
+
+max_retry: 2
+retry_sleep: 10
+error_output: $ERROR$
