-
Notifications
You must be signed in to change notification settings · Fork 4
feat: add support for the original mt-bench #21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 31 commits
ba4220d
d2a5a42
a828adb
0dcebf9
d60073b
38f63ee
6f5e0fc
42ff2ae
8fcb032
df958af
35856f2
6a11182
fecd3ed
0b4eaec
29340b0
2c294f1
4be61bf
51d2597
8dee7b2
fdc9410
48c5373
648a9be
14f747e
e67ea79
4089be8
03f5cce
8ffe3a6
b877f11
c2056b5
41cd15d
0ca66c5
a295305
0fb9700
e5670ea
6dd78fd
0094eea
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,17 +1,59 @@ | ||
| import pandas as pd | ||
| from langchain.prompts import ChatPromptTemplate | ||
| from typing import Any | ||
|
|
||
| from openjury.utils import ( | ||
| do_inference, | ||
| make_model, | ||
| truncate, | ||
| ) | ||
|
|
||
|
|
||
| def truncate(s: str, max_len: int | None = None): | ||
| if max_len is not None: | ||
| return s[:max_len] | ||
| else: | ||
| return s | ||
| def _set_temperature_on_model(chat_model, temperature: float) -> None: | ||
| if hasattr(chat_model, "set_temperature"): | ||
| chat_model.set_temperature(temperature) | ||
| return | ||
| if hasattr(chat_model, "temperature"): | ||
| setattr(chat_model, "temperature", temperature) | ||
|
|
||
|
|
||
def _infer_grouped_by_temperature(
    *,
    model_spec: str,
    provider: str,
    max_tokens: int | None,
    model_kwargs: dict[str, Any],
    base_model,
    inputs: list,
    temperatures: list[float],
    use_tqdm: bool,
) -> list[str]:
    """Run inference in temperature-homogeneous batches.

    Inputs are bucketed by their per-item temperature, each bucket is
    inferred as one batch, and the outputs are scattered back into the
    original input order. For local providers (VLLM/LlamaCpp) the already
    loaded ``base_model`` is mutated in place between batches; for remote
    providers a fresh model is constructed per temperature.

    Args:
        model_spec: Model specification string passed to ``make_model``.
        provider: Provider prefix parsed from the model spec.
        max_tokens: Generation token limit forwarded to ``make_model``.
        model_kwargs: Extra provider-specific kwargs for ``make_model``.
        base_model: Pre-built model reused for local providers.
        inputs: Prompt inputs, one per item.
        temperatures: Sampling temperature for each input (same length).
        use_tqdm: Whether to show a progress bar during inference.

    Returns:
        Completions aligned with ``inputs``.
    """
    by_temp: dict[float, list[int]] = {}
    for position, temp in enumerate(temperatures):
        by_temp.setdefault(float(temp), []).append(position)

    results: list[str] = [""] * len(inputs)
    # Sorted iteration keeps batch order deterministic across runs.
    for temp in sorted(by_temp):
        positions = by_temp[temp]
        if provider in {"VLLM", "LlamaCpp"}:
            _set_temperature_on_model(base_model, temp)
            group_model = base_model
        else:
            group_model = make_model(
                model_spec, max_tokens=max_tokens, temperature=temp, **model_kwargs
            )
        batch_outputs = do_inference(
            chat_model=group_model,
            inputs=[inputs[p] for p in positions],
            use_tqdm=use_tqdm,
        )
        for position, text in zip(positions, batch_outputs):
            results[position] = text

    return results
|
|
||
|
|
||
| def generate_instructions( | ||
|
|
@@ -57,6 +99,136 @@ def generate_instructions( | |
| return df_outputs | ||
|
|
||
|
|
||
def generate_multiturn(
    questions: pd.DataFrame,
    model: str,
    truncate_input_chars: int | None = 8192,
    max_tokens: int | None = 8192,
    use_tqdm: bool = True,
    temperature_config: dict[str, float] | None = None,
    **model_kwargs,
) -> pd.DataFrame:
    """Generate two-turn completions for MT-Bench style questions.

    Generates turn 1 answers first, then uses them as conversation context
    to generate turn 2 answers.

    Args:
        questions: DataFrame with columns turn_1, turn_2, and index instruction_index.
        model: Model specification string (e.g. "VLLM/model-name").
        truncate_input_chars: Character cap applied to each prompt piece
            before templating; None disables truncation.
        max_tokens: Generation token limit forwarded to make_model.
        use_tqdm: Whether to show a progress bar during inference.
        temperature_config: Optional category -> temperature mapping. When set,
            inputs are inferred in temperature-homogeneous groups to match
            MT-Bench/FastChat category defaults.
        **model_kwargs: Provider-specific options forwarded to make_model
            (e.g. max_model_len, chat_template for VLLM).

    Returns:
        DataFrame with columns: instruction_index, completion_turn_1, completion_turn_2
    """
    provider = model.split("/")[0]
    use_category_temperatures = temperature_config is not None
    local_provider = provider in {"VLLM", "LlamaCpp"}

    # Local providers are loaded once and have their temperature mutated per
    # group; remote providers get a fresh model per temperature group instead.
    if use_category_temperatures and local_provider:
        chat_model = make_model(model, max_tokens=max_tokens, temperature=0.0, **model_kwargs)
    else:
        chat_model = make_model(model, max_tokens=max_tokens, **model_kwargs)

    system_prompt = "You are a helpful assistant."

    idxs = questions.index.tolist()
    temperatures: list[float] = []
    if use_category_temperatures:
        # Unknown or missing categories fall back to 0.7 (FastChat default).
        temperatures = [
            temperature_config.get(str(questions.loc[idx].get("category") or ""), 0.7)
            for idx in idxs
        ]

    turn1_template = ChatPromptTemplate.from_messages(
        [("system", system_prompt), ("user", "{user_prompt}")]
    )

    turn1_inputs = turn1_template.batch(
        [
            {"user_prompt": truncate(row["turn_1"], max_len=truncate_input_chars)}
            for _, row in questions.iterrows()
        ]
    )

    print(f"Generating turn 1 completions ({len(turn1_inputs)} questions).")
    if use_category_temperatures:
        completions_turn_1 = _infer_grouped_by_temperature(
            model_spec=model,
            provider=provider,
            max_tokens=max_tokens,
            model_kwargs=model_kwargs,
            base_model=chat_model,
            inputs=turn1_inputs,
            temperatures=temperatures,
            use_tqdm=use_tqdm,
        )
    else:
        completions_turn_1 = do_inference(
            chat_model=chat_model,
            inputs=turn1_inputs,
            use_tqdm=use_tqdm,
        )

    turn2_inputs = []
    for (_, row), t1_answer in zip(questions.iterrows(), completions_turn_1):
        turn_2 = row["turn_2"]
        # BUGFIX: a missing follow-up read from a DataFrame is typically
        # NaN/pd.NA rather than None; checking `is None` alone would feed
        # the missing marker into the prompt template as a literal string.
        if turn_2 is None or pd.isna(turn_2):
            turn2_inputs.append(
                turn1_template.invoke({"user_prompt": "No follow-up question."})
            )
        else:
            multi_turn_template = ChatPromptTemplate.from_messages(
                [
                    ("system", system_prompt),
                    ("user", "{turn_1}"),
                    ("assistant", "{turn_1_answer}"),
                    ("user", "{turn_2}"),
                ]
            )
            turn2_inputs.append(
                multi_turn_template.invoke(
                    {
                        "turn_1": truncate(row["turn_1"], max_len=truncate_input_chars),
                        "turn_1_answer": truncate(str(t1_answer), max_len=truncate_input_chars),
                        "turn_2": truncate(turn_2, max_len=truncate_input_chars),
                    }
                )
            )

    print(f"Generating turn 2 completions ({len(turn2_inputs)} questions).")
    if use_category_temperatures:
        completions_turn_2 = _infer_grouped_by_temperature(
            model_spec=model,
            provider=provider,
            max_tokens=max_tokens,
            model_kwargs=model_kwargs,
            base_model=chat_model,
            inputs=turn2_inputs,
            temperatures=temperatures,
            use_tqdm=use_tqdm,
        )
    else:
        completions_turn_2 = do_inference(
            chat_model=chat_model,
            inputs=turn2_inputs,
            use_tqdm=use_tqdm,
        )

    df_outputs = pd.DataFrame(
        data={
            "instruction_index": idxs,
            "completion_turn_1": completions_turn_1,
            "completion_turn_2": completions_turn_2,
        },
    )
    return df_outputs
|
|
||
|
|
||
| def generate_base( | ||
| instructions: pd.Series, | ||
| model: str, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💪