
Commit 807b66f

Merge: 2 parents c90b8fc + 853168f

File tree

  fastchat/conversation.py
  fastchat/serve/api_provider.py
  fastchat/serve/monitor/elo_analysis.py
  fastchat/serve/monitor/monitor_md.py
  fastchat/serve/vision/create_vqa_examples_dir.py
  fastchat/serve/vision/create_vqa_examples_json.py

6 files changed: +123 −22 lines changed

fastchat/conversation.py

Lines changed: 25 additions & 0 deletions

@@ -576,6 +576,31 @@ def to_reka_api_messages(self):
 
         return ret
 
+    def to_metagen_api_messages(self):
+        """Convert the conversation to MetaGen (Meta) chat completion format."""
+        if self.system_message == "":
+            ret = []
+        else:
+            ret = [{"role": "system", "text": self.system_message}]
+
+        for i, (_, msg) in enumerate(self.messages[self.offset :]):
+            if i % 2 == 0:
+                if type(msg) is tuple:
+                    text, images = msg[0], msg[1]
+                    # Currently only support one image.
+                    attachment = {
+                        "type": "base64_image",
+                        "mime": "image/jpeg",
+                        "data": images[-1].base64_str,
+                    }
+                    ret.append({"role": "user", "text": text, "attachment": attachment})
+                else:
+                    ret.append({"role": "user", "text": msg})
+            else:
+                if msg is not None:
+                    ret.append({"role": "ai", "text": msg})
+        return ret
+
     def save_new_images(self, has_csam_images=False, use_remote_storage=False):
         import hashlib
         from fastchat.constants import LOGDIR
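
Note on the new converter: the MetaGen schema used above keys message content with "text" (not "content") and labels assistant turns "ai" (not "assistant"). A minimal sketch of what to_metagen_api_messages would emit for a hypothetical one-image conversation (sample contents are made up; roles and keys follow the diff):

    # Hypothetical output: system prompt, one user turn with an image, one reply.
    [
        {"role": "system", "text": "You are a helpful assistant."},
        {
            "role": "user",
            "text": "What is in this photo?",
            "attachment": {
                "type": "base64_image",
                "mime": "image/jpeg",
                "data": "<base64-encoded JPEG>",  # only the last image is sent
            },
        },
        {"role": "ai", "text": "A cat on a windowsill."},
    ]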

fastchat/serve/api_provider.py

Lines changed: 69 additions & 7 deletions

@@ -203,6 +203,17 @@ def get_api_provider_stream_iter(
             api_base=model_api_dict["api_base"],
             api_key=model_api_dict["api_key"],
         )
+    elif model_api_dict["api_type"] == "metagen":
+        prompt = conv.to_metagen_api_messages()
+        stream_iter = metagen_api_stream_iter(
+            model_api_dict["model_name"],
+            prompt,
+            temperature,
+            top_p,
+            max_new_tokens,
+            api_base=model_api_dict["api_base"],
+            api_key=model_api_dict["api_key"],
+        )
     else:
         raise NotImplementedError()
 
@@ -1115,11 +1126,62 @@ def reka_api_stream_iter(
         model=model_name,
     )
 
-    for chunk in response:
-        try:
-            yield {"text": chunk.responses[0].chunk.content, "error_code": 0}
-        except:
-            yield {
-                "text": f"**API REQUEST ERROR** ",
-                "error_code": 1,
+    if response.status_code != 200:
+        error_message = response.text
+        logger.error(f"==== error from reka api: {error_message} ====")
+        yield {
+            "text": f"**API REQUEST ERROR** Reason: {error_message}",
+            "error_code": 1,
+        }
+        return
+
+    for line in response.iter_lines():
+        line = line.decode("utf8")
+        if not line.startswith("data: "):
+            continue
+        gen = json.loads(line[6:])
+        yield {"text": gen["text"], "error_code": 0}
+
+
+def metagen_api_stream_iter(
+    model_name,
+    messages,
+    temperature,
+    top_p,
+    max_new_tokens,
+    api_key,
+    api_base,
+):
+    res = requests.post(
+        f"{api_base}/chat_stream_completions?access_token={api_key}",
+        stream=True,
+        headers={"Content-Type": "application/json"},
+        json={
+            "model": model_name,
+            "chunks_delimited": True,
+            "messages": messages,
+            "options": {
+                "max_tokens": max_new_tokens,
+                "generation_algorithm": "top_p",
+                "top_p": top_p,
+                "temperature": temperature,
+            },
+        },
+        timeout=40,
+    )
+
+    if res.status_code != 200:
+        logger.error(f"Unexpected response ({res.status_code}): {res.text}")
+        raise ValueError("Unexpected response: ", res.json())
+
+    text = ""
+    for line in res.iter_lines():
+        if line:
+            part = json.loads(line.decode("utf-8"))
+            if "text" in part:
+                text += part["text"]
+            data = {
+                "text": text,
+                "error_code": 0,
            }
+            yield data
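
Unlike the SSE-style "data: " lines the Reka branch parses, the MetaGen endpoint streams newline-delimited JSON, and each yielded dict carries the full text accumulated so far. A self-contained sketch of that parsing loop run against a simulated stream (the payloads are made up; only the accumulation logic mirrors the loop above, assuming each chunk's "text" is a delta):

    import json

    # Simulated res.iter_lines() output: JSON chunks plus an empty
    # keep-alive line that the `if line:` guard skips.
    raw_lines = [b'{"text": "Hello"}', b"", b'{"text": ", world"}']

    text = ""
    for line in raw_lines:
        if line:
            part = json.loads(line.decode("utf-8"))
            if "text" in part:
                text += part["text"]
            print({"text": text, "error_code": 0})
    # -> {'text': 'Hello', 'error_code': 0}
    # -> {'text': 'Hello, world', 'error_code': 0}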

fastchat/serve/monitor/elo_analysis.py

Lines changed: 19 additions & 4 deletions

@@ -495,18 +495,33 @@ def construct_style_matrices(
     return X, Y, models
 
 
-def get_bootstrap_result_style_control(X, Y, models, func_compute_elo, num_round=1000):
+def get_bootstrap_result_style_control(
+    X, Y, battles, models, func_compute_elo, num_round=1000
+):
     elos = []
     coefs = []
     assert X.shape[0] % 2 == 0 and X.shape[0] == Y.shape[0]
     k = int(
         X.shape[0] / 2
     )  # Since we duplicate the battles when constructing X and Y, we don't want to sample the duplicates
 
+    battles_tie_idx = (battles["winner"] == "tie") | (
+        battles["winner"] == "tie (bothbad)"
+    )
     for _ in tqdm(range(num_round), desc="bootstrap"):
         indices = np.random.choice(list(range(k)), size=(k), replace=True)
-        _X = np.concatenate([X[indices], X[indices]])
-        _Y = np.concatenate([Y[indices], Y[indices]])
+
+        index2tie = np.zeros(k, dtype=bool)
+        index2tie[battles_tie_idx] = True
+
+        nontie_indices = indices[~index2tie[indices]]
+        tie_indices = np.concatenate(
+            [indices[index2tie[indices]], indices[index2tie[indices]] + k]
+        )
+
+        _X = np.concatenate([X[nontie_indices], X[nontie_indices], X[tie_indices]])
+        _Y = np.concatenate([Y[nontie_indices], Y[nontie_indices], Y[tie_indices]])
+
        assert _X.shape == X.shape and _Y.shape == Y.shape
 
        states = ~_X[:, : len(models)].any(axis=0)

@@ -585,7 +600,7 @@ def report_elo_analysis_results(
     if style_control:
         X, Y, models = construct_style_matrices(battles)
         bootstrap_df, boostrap_coef = get_bootstrap_result_style_control(
-            X, Y, models, fit_mle_elo, num_round=num_bootstrap
+            X, Y, battles, models, fit_mle_elo, num_round=num_bootstrap
         )
         elo_rating_final, coef_final = fit_mle_elo(X, Y, models)
     else:
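
Why the extra index bookkeeping: construct_style_matrices stores every battle twice (row i and row i + k), and for tie battles the two copies differ, since a tie is presumably encoded as one win-labeled and one loss-labeled row. The old bootstrap stacked two copies of row i for every sampled battle, which is only right for non-ties; the new code resamples the first k rows, duplicates non-tie rows, and for ties picks up both stored rows i and i + k. A toy numpy sketch of the index arithmetic (sizes and tie mask are hypothetical):

    import numpy as np

    k = 4  # unique battles; X has 2 * k rows (each battle duplicated)
    index2tie = np.array([False, True, False, True])  # hypothetical tie mask
    indices = np.random.choice(np.arange(k), size=k, replace=True)

    nontie_indices = indices[~index2tie[indices]]
    # Ties pull in both stored orientations: row i and row i + k.
    tie_indices = np.concatenate(
        [indices[index2tie[indices]], indices[index2tie[indices]] + k]
    )

    rows = np.concatenate([nontie_indices, nontie_indices, tie_indices])
    assert rows.shape[0] == 2 * k  # the resampled matrix keeps X's shape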

fastchat/serve/monitor/monitor_md.py

Lines changed: 4 additions & 0 deletions

@@ -6,13 +6,15 @@
 
 key_to_category_name = {
     "full": "Overall",
+    "full_style_control": "Overall w/ Style Control",
     "dedup": "De-duplicate Top Redundant Queries (soon to be default)",
     "math": "Math",
     "if": "Instruction Following",
     "multiturn": "Multi-Turn",
     "coding": "Coding",
     "hard_6": "Hard Prompts (Overall)",
     "hard_english_6": "Hard Prompts (English)",
+    "hard_6_style_control": "Hard Prompts (Overall) w/ Style Control",
     "long_user": "Longer Query",
     "english": "English",
     "chinese": "Chinese",
@@ -30,12 +32,14 @@
 }
 cat_name_to_explanation = {
     "Overall": "Overall Questions",
+    "Overall w/ Style Control": "Overall with Style Control",
     "De-duplicate Top Redundant Queries (soon to be default)": "De-duplicate top redundant queries (top 0.1%). See details in [blog post](https://lmsys.org/blog/2024-05-17-category-hard/#note-enhancing-quality-through-de-duplication).",
     "Math": "Math",
     "Instruction Following": "Instruction Following",
     "Multi-Turn": "Multi-Turn Conversation (>= 2 turns)",
     "Coding": "Coding: whether conversation contains code snippets",
     "Hard Prompts (Overall)": "Hard Prompts (Overall): details in [blog post](https://lmsys.org/blog/2024-05-17-category-hard/)",
+    "Hard Prompts (Overall) w/ Style Control": "Hard Prompts (Overall) with Style Control",
     "Hard Prompts (English)": "Hard Prompts (English), note: the delta is to English Category. details in [blog post](https://lmsys.org/blog/2024-05-17-category-hard/)",
     "Longer Query": "Longer Query (>= 500 tokens)",
     "English": "English Prompts",

fastchat/serve/vision/create_vqa_examples_dir.py

Lines changed: 0 additions & 8 deletions

@@ -64,14 +64,6 @@ def download_images_and_create_json(
 args = parser.parse_args()
 
 datasets_info = {
-    "realworldqa": {
-        "path": "visheratin/realworldqa",
-        "image_key": "image",
-        "question_key": "question",
-        "id_key": "index",
-        "subset": False,
-        "split": "test",
-    },
     "Memes": {
         "path": "not-lain/meme-dataset",
         "image_key": "image",

fastchat/serve/vision/create_vqa_examples_json.py

Lines changed: 6 additions & 3 deletions

@@ -17,19 +17,22 @@
 args = parser.parse_args()
 
 dataset_prop = {
-    "realworldqa": 500,
     "Memes": 500,
     "Floorplan": 500,
     "Website": 500,
-    "IllusionVQA": 500,
+    "IllusionVQA": 435,
     "NewYorker": 500,
 }
 
 dataset_json = []
 for dataset_name in dataset_prop.keys():
     with open(f"{args.output_dir}/{dataset_name}/data.json") as f:
         data = json.load(f)
-    dataset_json.extend(np.random.choice(data, dataset_prop[dataset_name]))
+    dataset_json.extend(
+        np.random.choice(
+            data, min(dataset_prop[dataset_name], len(data)), replace=False
+        )
+    )
 
 with open(f"{args.output_dir}/metadata_sampled.json", "w") as f:
     json.dump(dataset_json, f, indent=4)
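
Two behavior changes hide in that one call: np.random.choice defaults to replace=True, so the old call could sample the same example twice, and replace=False alone raises once a dataset holds fewer items than requested, which the min() guard prevents (IllusionVQA's quota drops to 435 accordingly). A quick illustration with toy data (integers stand in for the example dicts):

    import numpy as np

    data = list(range(435))  # e.g. a dataset with only 435 examples
    old = np.random.choice(data, 500)  # old call: replacement, duplicates possible
    # np.random.choice(data, 500, replace=False)  # would raise ValueError
    new = np.random.choice(data, min(500, len(data)), replace=False)  # new call
    assert len(new) == len(set(new)) == 435  # unique, capped at dataset size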
