-
Notifications
You must be signed in to change notification settings - Fork 4
feat: add support for the original mt-bench #21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
ba4220d
d2a5a42
a828adb
0dcebf9
d60073b
38f63ee
6f5e0fc
42ff2ae
8fcb032
df958af
35856f2
6a11182
fecd3ed
0b4eaec
29340b0
2c294f1
4be61bf
51d2597
8dee7b2
fdc9410
48c5373
648a9be
14f747e
e67ea79
4089be8
03f5cce
8ffe3a6
b877f11
c2056b5
41cd15d
0ca66c5
a295305
0fb9700
e5670ea
6dd78fd
0094eea
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,212 @@ | ||
| """CLI argument configuration for generation and evaluation entrypoints.""" | ||
|
|
||
| import argparse | ||
| import json | ||
| from dataclasses import dataclass, field | ||
|
|
||
|
|
||
| @dataclass | ||
| class CliArgs: | ||
| dataset: str | ||
| model_A: str | ||
| model_B: str | ||
| judge_model: str | ||
|
|
||
| n_instructions: int | None = None | ||
| provide_explanation: bool = False | ||
| swap_mode: str = "fixed" | ||
| ignore_cache: bool = False | ||
| use_tqdm: bool = False | ||
| truncate_all_input_chars: int = 8192 | ||
| max_out_tokens_models: int = 32768 | ||
| max_out_tokens_judge: int = 32768 | ||
| max_model_len: int | None = None | ||
| chat_template: str | None = None | ||
| mt_bench_turns: str = "both" | ||
| mt_bench_compatibility: str = "openjury" | ||
|
Comment on lines
+25
to
+26
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Right now this CLI is unioning both single and multi turn. We will also have specific arguments, for instance for ELO arena. We can either have one class for each entrypoint (single turn, multi turn, ELO rating, ...) or a class with common CLI args and another one for specific args. Either option would work for me but I would rather avoid an approach that puts everything in this class.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That being said, we could also do this in another PR to avoid blocking on this given that this is internal stuff. |
||
| result_folder: str = "results" | ||
| engine_kwargs: dict = field(default_factory=dict) | ||
|
|
||
| def __post_init__(self): | ||
| supported_modes = ["fixed", "both"] | ||
| assert ( | ||
| self.swap_mode in supported_modes | ||
| ), f"Only {supported_modes} modes are supported but got {self.swap_mode}." | ||
| supported_mt_bench_modes = ["openjury", "fastchat"] | ||
| assert ( | ||
| self.mt_bench_compatibility in supported_mt_bench_modes | ||
| ), f"Only {supported_mt_bench_modes} are supported but got {self.mt_bench_compatibility}." | ||
|
|
||
| @classmethod | ||
| def parse_args(cls): | ||
| parser = argparse.ArgumentParser( | ||
| prog="Generate completion and evaluate with a judge", | ||
| ) | ||
| parser.add_argument( | ||
| "--dataset", | ||
| help="The dataset to use. For instance `alpaca-eval`, `arena-hard`, `m-arena-hard-EU` for instruction " | ||
| "tuning cases or `french-contexts`, `spanish-contexts` for base models.", | ||
| ) | ||
| parser.add_argument( | ||
| "--model_A", | ||
| required=True, | ||
| help="Name of the LLM to use for a generation, must be a valid choice for `generation_provider`", | ||
| ) | ||
| parser.add_argument( | ||
| "--model_B", | ||
| required=True, | ||
| help="Name of the LLM to use for a generation, must be a valid choice for `generation_provider`", | ||
| ) | ||
| parser.add_argument( | ||
| "--judge_model", | ||
| required=True, | ||
| help="Name of the LLM to use, for instance `Together/meta-llama/Meta-Llama-3-70B-Instruct-Turbo`, " | ||
| "`VLLM/meta-llama/Meta-Llama-3-70B-Instruct-Turbo`, `LangChain/LocalPath` etc", | ||
| ) | ||
| parser.add_argument( | ||
| "--n_instructions", | ||
| type=int, | ||
| required=False, | ||
| ) | ||
| parser.add_argument( | ||
| "--provide_explanation", | ||
| action="store_true", | ||
| help="If specified, judge will provide explanation before making a judgement. Does not necessarily improve" | ||
| "the accuracy of the judge but enables some result interpretation.", | ||
| ) | ||
| parser.add_argument( | ||
| "--swap_mode", | ||
| type=str, | ||
| choices=["fixed", "both"], | ||
| default="fixed", | ||
| help="Model comparison order mode. 'fixed': always use model order A-B. 'both': correct for model order " | ||
| "bias by evaluating each instruction twice, once as A-B and once as B-A, and average. This helps account " | ||
| "for judge position bias. Default is 'fixed'.", | ||
| ) | ||
| parser.add_argument( | ||
| "--ignore_cache", | ||
| action="store_true", | ||
| help="If specified, ignore cache of previous completions.", | ||
| ) | ||
| parser.add_argument( | ||
| "--use_tqdm", | ||
| action="store_true", | ||
| help="If specified, use tqdm, does not work with all model providers, vLLM in particular.", | ||
| ) | ||
| parser.add_argument( | ||
| "--result_folder", | ||
| type=str, | ||
| required=False, | ||
| default="results", | ||
| help="The folder to save the results. Defaults to `results`. Evaluation results will be saved in" | ||
| " `[result_folder]/[evaluation_name]`.", | ||
| ) | ||
| parser.add_argument( | ||
| "--truncate_all_input_chars", | ||
| type=int, | ||
| required=False, | ||
| default=8192, | ||
| help="Character-level truncation applied before tokenization: truncates each instruction " | ||
| "before model A/B generation and truncates each completion before judge evaluation.", | ||
| ) | ||
| parser.add_argument( | ||
| "--max_out_tokens_models", | ||
| type=int, | ||
| required=False, | ||
| default=32768, | ||
| help=( | ||
| "Generation token budget for each model A/B response. For VLLM, keep this <= " | ||
| "--max_model_len (if provided)." | ||
| ), | ||
| ) | ||
| parser.add_argument( | ||
| "--max_out_tokens_judge", | ||
| type=int, | ||
| required=False, | ||
| default=32768, | ||
| help=( | ||
| "Generation token budget for the judge response (reasoning + scores). For " | ||
| "VLLM, keep this <= --max_model_len (if provided)." | ||
| ), | ||
| ) | ||
| parser.add_argument( | ||
| "--max_model_len", | ||
| type=int, | ||
| required=False, | ||
| default=None, | ||
| help=( | ||
| "Optional total context window for VLLM models (prompt + generation). This is " | ||
| "independent from --max_out_tokens_models/--max_out_tokens_judge, which only cap " | ||
| "generated tokens. This is useful on smaller GPUs to avoid OOM." | ||
| ), | ||
| ) | ||
| parser.add_argument( | ||
| "--chat_template", | ||
| type=str, | ||
| required=False, | ||
| default=None, | ||
| help="Jinja2 chat template string to use instead of the model's tokenizer template. " | ||
| "If not provided, ChatML is used as fallback for models without a chat template.", | ||
| ) | ||
| parser.add_argument( | ||
| "--mt_bench_turns", | ||
| type=str, | ||
| choices=["both", "single", "multi"], | ||
| default="both", | ||
| help="Which MT-Bench turns to evaluate. 'single': only turn 1, " | ||
| "'multi': only turn 2 (with full conversation context), " | ||
| "'both' (default): evaluate both turns.", | ||
| ) | ||
| parser.add_argument( | ||
| "--mt_bench_compatibility", | ||
| type=str, | ||
| choices=["openjury", "fastchat"], | ||
| default="openjury", | ||
| help=( | ||
| "MT-Bench evaluation/generation mode. " | ||
| "'openjury' (default): OpenJury score_A/score_B prompt + softmax preference. " | ||
| "'fastchat': use FastChat/MT-Bench pairwise prompts with [[A]]/[[B]]/[[C]] verdict parsing, " | ||
| "conservative position-bias handling, judge temperature=0, and MT-Bench category temperatures." | ||
| ), | ||
| ) | ||
| parser.add_argument( | ||
| "--engine_kwargs", | ||
| type=str, | ||
| required=False, | ||
| default="{}", | ||
| help=( | ||
| "JSON dict of engine-specific kwargs forwarded to the underlying engine. " | ||
| "Example for vLLM: '{\"tensor_parallel_size\": 2, \"gpu_memory_utilization\": 0.9}'." | ||
| ), | ||
| ) | ||
| args = parser.parse_args() | ||
|
|
||
| try: | ||
| engine_kwargs = ( | ||
| json.loads(args.engine_kwargs) if args.engine_kwargs else {} | ||
| ) | ||
| if not isinstance(engine_kwargs, dict): | ||
| raise ValueError("engine_kwargs must be a JSON object") | ||
| except Exception as e: | ||
| raise SystemExit(f"Failed to parse --engine_kwargs: {e}") | ||
|
|
||
| return cls( | ||
| dataset=args.dataset, | ||
| model_A=args.model_A, | ||
| model_B=args.model_B, | ||
| judge_model=args.judge_model, | ||
| n_instructions=args.n_instructions, | ||
| provide_explanation=args.provide_explanation, | ||
| swap_mode=args.swap_mode, | ||
| ignore_cache=args.ignore_cache, | ||
| use_tqdm=args.use_tqdm, | ||
| truncate_all_input_chars=args.truncate_all_input_chars, | ||
| max_out_tokens_models=args.max_out_tokens_models, | ||
| max_out_tokens_judge=args.max_out_tokens_judge, | ||
| max_model_len=args.max_model_len, | ||
| chat_template=args.chat_template, | ||
| mt_bench_turns=args.mt_bench_turns, | ||
| mt_bench_compatibility=args.mt_bench_compatibility, | ||
| result_folder=args.result_folder, | ||
| engine_kwargs=engine_kwargs, | ||
| ) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💪