Skip to content

Commit 35a6e86

Browse files
Fix CLI regressions (#3449)
* fix regressions in cli * linting? * still allow space delimited * use `SplitArgs` for --tasks; use nargs="+"; tests --------- Co-authored-by: Baber <[email protected]>
1 parent cdb4253 commit 35a6e86

File tree

6 files changed

+227
-38
lines changed

6 files changed

+227
-38
lines changed

lm_eval/_cli/run.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from lm_eval._cli.subcommand import SubCommand
99
from lm_eval._cli.utils import (
1010
MergeDictAction,
11+
SplitArgs,
1112
_int_or_none_list_arg_type,
1213
request_caching_arg_to_dict,
1314
try_parse_json,
@@ -65,11 +66,11 @@ def _add_args(self) -> None:
6566
"--tasks",
6667
"-t",
6768
default=None,
68-
type=str,
69-
nargs="*",
69+
nargs="+",
7070
metavar="<task>",
71+
action=SplitArgs,
7172
help=textwrap.dedent("""
72-
Space or Comma-separated list of task names or groupings.
73+
Space (or comma-separated) list of task names or groupings.
7374
Use 'lm-eval list tasks' to see all available tasks.
7475
""").strip(),
7576
)
@@ -85,7 +86,7 @@ def _add_args(self) -> None:
8586
"--model_args",
8687
"-a",
8788
default=None,
88-
nargs="*",
89+
nargs="+",
8990
action=MergeDictAction,
9091
metavar="<arg>",
9192
help="Model arguments as 'key=val,key2=val2' or `key=val` `key2=val2`",
@@ -153,7 +154,7 @@ def _add_args(self) -> None:
153154
eval_group.add_argument(
154155
"--gen_kwargs",
155156
default=None,
156-
nargs="*",
157+
nargs="+",
157158
action=MergeDictAction,
158159
metavar="<arg>",
159160
help=textwrap.dedent(
@@ -265,23 +266,23 @@ def _add_args(self) -> None:
265266
logging_group.add_argument(
266267
"--wandb_args",
267268
default=None,
268-
nargs="*",
269+
nargs="+",
269270
action=MergeDictAction,
270271
metavar="<args>",
271272
help="Weights & Biases init arguments key=val key2=val2",
272273
)
273274
logging_group.add_argument(
274275
"--wandb_config_args",
275276
default=None,
276-
nargs="*",
277+
nargs="+",
277278
action=MergeDictAction,
278279
metavar="<args>",
279280
help="Weights & Biases config arguments key=val key2=val2",
280281
)
281282
logging_group.add_argument(
282283
"--hf_hub_log_args",
283284
default=None,
284-
nargs="*",
285+
nargs="+",
285286
action=MergeDictAction,
286287
metavar="<args>",
287288
help="Hugging Face Hub logging arguments key=val key2=val2",

lm_eval/_cli/utils.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,17 @@ def try_parse_json(value: str | dict[str, Any] | None) -> str | dict[str, Any] |
2121
if "{" in value:
2222
raise ValueError(
2323
f"Invalid JSON: {value}. Hint: Use double quotes for JSON strings."
24-
)
24+
) from None
2525
return value
2626

2727

2828
def _int_or_none_list_arg_type(
2929
min_len: int, max_len: int, defaults: str, value: str, split_char: str = ","
3030
) -> list[int | None]:
3131
"""Parses a string of integers or 'None' values separated by a specified character into a list.
32-
Validates the number of items against specified minimum and maximum lengths and fills missing values with defaults."""
32+
33+
Validates the number of items against specified minimum and maximum lengths and fills missing values with defaults.
34+
"""
3335

3436
def parse_value(item):
3537
"""Parses an individual item, converting it to an integer or `None`."""
@@ -39,7 +41,7 @@ def parse_value(item):
3941
try:
4042
return int(item)
4143
except ValueError:
42-
raise ValueError(f"{item} is not an integer or None")
44+
raise ValueError(f"{item} is not an integer or None") from None
4345

4446
items = [parse_value(v) for v in value.split(split_char)]
4547
num_items = len(items)
@@ -109,6 +111,7 @@ def key_val_to_dict(args: str) -> dict[str, Any]:
109111
res = {}
110112
if not args:
111113
return res
114+
112115
for k, v in (item.split("=") for item in args.split(",")):
113116
v = handle_cli_value_string(v)
114117
if k in res:
@@ -128,13 +131,34 @@ def __call__(
128131
option_string: str | None = None,
129132
) -> None:
130133
current = vars(namespace).setdefault(self.dest, {}) or {}
131-
if values:
132-
for v in values:
133-
v = key_val_to_dict(v)
134-
if overlap := current.keys() & v.keys():
135-
eval_logger.warning(
136-
f"{option_string or self.dest}: Overwriting key {', '.join(f'{k}: {current[k]!r} -> {v[k]!r}' for k in overlap)}"
137-
)
138-
139-
current.update(v)
134+
135+
if not values:
136+
return
137+
138+
# e.g. parses `{"pretrained":"/models/openai_gpt-oss-20b","dtype":"auto","chat_template_args":{"reasoning_effort":"low"},"enable_thinking": true,"think_end_token":"<|message|>"}`.
139+
result = try_parse_json(values[0])
140+
141+
if isinstance(result, dict):
142+
current = {**current, **result}
143+
else:
144+
# e.g. parses `max_gen_toks=8000`
145+
if values:
146+
for v in values:
147+
v = key_val_to_dict(v)
148+
if overlap := current.keys() & v.keys():
149+
eval_logger.warning(
150+
rf"{option_string or self.dest}: Overwriting {', '.join(f'{k}: {current[k]!r} -> {v[k]!r}' for k in overlap)}"
151+
)
152+
current.update(v)
153+
140154
setattr(namespace, self.dest, current)
155+
156+
157+
class SplitArgs(argparse.Action):
158+
def __call__(self, parser, namespace, values, option_string=None):
159+
items = getattr(namespace, self.dest) or []
160+
values = values or []
161+
assert values, f"--{self.dest} passed without any values"
162+
for v in values:
163+
items.extend(v.split(","))
164+
setattr(namespace, self.dest, items)

lm_eval/config/evaluate_config.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def from_cli(cls, namespace: Namespace) -> "EvaluatorConfig":
203203

204204
# Load and merge YAML config if provided
205205
if used_config := getattr(namespace, "config", None):
206-
config.update(cls.load_yaml_config(cast(str, used_config)))
206+
config.update(cls.load_yaml_config(cast("str", used_config)))
207207

208208
# Override with CLI args (only truthy values or 0, exclude non-config args)
209209
excluded_args = {"command", "func"} # argparse internal args
@@ -320,7 +320,7 @@ def _process_arguments(self):
320320
try:
321321
self.samples = json.loads(self.samples)
322322
except json.JSONDecodeError:
323-
if (samples_path := Path(cast(str, self.samples))).is_file():
323+
if (samples_path := Path(cast("str", self.samples))).is_file():
324324
self.samples = json.loads(samples_path.read_text())
325325

326326
# Set up metadata by merging model_args and metadata.
@@ -358,8 +358,11 @@ def process_tasks(self, metadata: dict | None = None) -> "TaskManager":
358358
)
359359

360360
# Normalize tasks to a list
361+
# We still allow tasks in the form task1,task2
361362
task_list = (
362-
self.tasks.split(",") if isinstance(self.tasks, str) else list(self.tasks)
363+
self.tasks.split(",")
364+
if isinstance(self.tasks, str)
365+
else [t for task in self.tasks for t in task.split(",")]
363366
)
364367

365368
# Handle directory input

lm_eval/evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def simple_evaluate(
209209
"No tasks specified, or no tasks found. Please verify the task names."
210210
)
211211

212-
if gen_kwargs is not None:
212+
if gen_kwargs:
213213
if isinstance(gen_kwargs, str):
214214
gen_kwargs = simple_parse_args_string(gen_kwargs)
215215
eval_logger.warning(

lm_eval/evaluator_utils.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import math
44
import pathlib
55
import sys
6-
from typing import List, Optional, Tuple, Union
76

87
from lm_eval.api.group import ConfigurableGroup
98
from lm_eval.api.metrics import (
@@ -139,7 +138,7 @@ def __repr__(self):
139138
)
140139

141140

142-
def get_task_list(task_dict: dict) -> List[TaskOutput]:
141+
def get_task_list(task_dict: dict) -> list[TaskOutput]:
143142
outputs = []
144143
for task_name, task_obj in task_dict.items():
145144
if isinstance(task_obj, dict):
@@ -210,7 +209,7 @@ def print_writeout(task) -> None:
210209
eval_logger.info(f"Request: {str(inst)}")
211210

212211

213-
def get_sample_size(task, limit: Optional[int]) -> Union[int, None]:
212+
def get_sample_size(task, limit: int | None) -> int | None:
214213
if limit is not None:
215214
limit = (
216215
int(math.ceil(len(task.eval_docs) * limit)) if limit < 1.0 else int(limit)
@@ -223,7 +222,7 @@ def prepare_print_tasks(
223222
results: dict,
224223
task_depth=0,
225224
group_depth=0,
226-
) -> Tuple[dict, dict]:
225+
) -> tuple[dict, dict]:
227226
"""
228227
@param task_dict: Dictionary representing the group hierarchy of tasks. Each key is a group name and its
229228
value is a list of task names.
@@ -311,8 +310,8 @@ def _sort_task_dict(task_dict):
311310

312311

313312
def consolidate_results(
314-
eval_tasks: List[TaskOutput],
315-
) -> Tuple[dict, dict, dict, dict, dict, dict]:
313+
eval_tasks: list[TaskOutput],
314+
) -> tuple[dict, dict, dict, dict, dict, dict]:
316315
"""
317316
@param eval_tasks: list(TaskOutput).
318317
@return: A tuple containing the consolidated results, samples, configs, versions, and num_fewshot.
@@ -379,7 +378,7 @@ def consolidate_group_results(
379378
task_root=None,
380379
show_group_table=False,
381380
task_aggregation_list=None,
382-
) -> Tuple[dict, dict, bool, Union[None,]]:
381+
) -> tuple[dict, dict, bool, None]:
383382
"""
384383
(Recursively) calculates groups' aggregated metrics and updates the results and versions dictionaries with this info.
385384
@@ -548,7 +547,7 @@ def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
548547

549548

550549
@positional_deprecated
551-
def run_task_tests(task_list: List[str]):
550+
def run_task_tests(task_list: list[str]):
552551
"""
553552
Find the package root and run the tests for the given tasks
554553
"""

0 commit comments

Comments
 (0)