diff --git a/auto_round/__main__.py b/auto_round/__main__.py
index 76a8f73d1..97b34864c 100644
--- a/auto_round/__main__.py
+++ b/auto_round/__main__.py
@@ -367,6 +367,24 @@ def __init__(self, *args, **kwargs):
             "Options: 'float16', 'bfloat16', 'float32'. "
             "Should match your hardware capabilities for best performance.",
         )
+        eval_args.add_argument(
+            "--task_configs",
+            type=str,
+            default=None,
+            help=(
+                "Optional per-task configuration in JSON or simplified format. "
+                "Example JSON: "
+                '\'{"gsm8k_llama": {"apply_chat_template": true, "fewshot_as_multiturn": true}, '
+                ' "hellaswag": {"num_fewshot": 10}}\' '
+                "You can also provide a JSON file path like 'task_configs.json'."
+            ),
+        )
+        eval_args.add_argument(
+            "--disable_thinking",
+            action="store_true",
+            help="Whether to disable the thinking mode of the chat template.",
+        )
+        eval_args.add_argument("--max_length", default=None, type=int, help="Max generation length for evaluation.")
 
         ## ======================= MLLM =======================
         mllm_args = self.add_argument_group("Multimodal Large Language Model(MLLM) arguments")
@@ -735,6 +753,9 @@ def tune(args):
             limit=args.limit,
             batch_size=args.eval_bs,
             eval_model_dtype=eval_model_dtype,
+            task_configs=args.task_configs,
+            disable_thinking=args.disable_thinking,
+            max_length=args.max_length,
         )
     else:
         if args.eval_bs is None or args.eval_bs == "auto":
@@ -763,11 +784,15 @@ def tune(args):
             eval_task_by_task(
                 eval_folder,
                 device=device_str,
+                tokenizer=tokenizer,
                 tasks=args.tasks,
                 batch_size=args.eval_bs,
                 limit=args.limit,
                 eval_model_dtype=eval_model_dtype,
                 mllm=autoround.mllm,  # pylint: disable=E1101
+                task_configs=args.task_configs,
+                disable_thinking=args.disable_thinking,
+                max_length=args.max_length,
             )
         else:
             from auto_round.eval.evaluation import simple_evaluate
@@ -821,6 +846,9 @@ def run_eval():
             batch_size=args.eval_bs,
             trust_remote_code=not args.disable_trust_remote_code,
             eval_model_dtype=args.eval_model_dtype,
+            task_configs=args.task_configs,
+            disable_thinking=args.disable_thinking,
+            max_length=args.max_length,
         )
     else:
         eval(args)
diff --git a/auto_round/eval/eval_cli.py b/auto_round/eval/eval_cli.py
index 009b6458d..76b7065fe 100644
--- a/auto_round/eval/eval_cli.py
+++ b/auto_round/eval/eval_cli.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
+import json
 import os
 import time
 
@@ -101,6 +102,24 @@ def __init__(self, *args, **kwargs):
             choices=["hf", "vllm"],
             help="Backend to use for model evaluation. Use hf backend for evaluation by default.",
         )
+        self.add_argument(
+            "--task_configs",
+            type=str,
+            default=None,
+            help=(
+                "Optional per-task configuration in JSON or simplified format. "
+                "Example JSON: "
+                '\'{"gsm8k_llama": {"apply_chat_template": true, "fewshot_as_multiturn": true}, '
+                ' "hellaswag": {"num_fewshot": 10}}\' '
+                "You can also provide a JSON file path like 'task_configs.json'."
+            ),
+        )
+        self.add_argument(
+            "--disable_thinking",
+            action="store_true",
+            help="Whether to disable the thinking mode of the chat template.",
+        )
+        self.add_argument("--max_length", default=None, type=int, help="Max generation length for evaluation.")
 
         # vllm related arguments
         vllm_args = self.add_argument_group("vllm backend arguments")
@@ -221,7 +240,34 @@ def eval_task_by_task(
     eval_model_dtype=None,
     retry_times=3,
     mllm=False,
+    task_configs=None,  # e.g. {"gsm8k": {"apply_chat_template": True, "fewshot_as_multiturn": True}}
+    disable_thinking=False,
+    max_length=None,  # defaults to the model's original setting
 ):
+    """
+    Evaluate each LM-eval task sequentially, with optional per-task overrides.
+
+    Args:
+        model (str | nn.Module): Model path or loaded model.
+        device (str): Device id (e.g. "0" or "cuda:0").
+        tasks (str | list[str]): Comma-separated task names or a list of tasks.
+        tokenizer: HuggingFace tokenizer.
+        batch_size: Eval batch size (default: "auto:8").
+        limit: Number of samples or fraction per task.
+        task_configs (dict): Optional task-specific settings such as fewshot/chat-template overrides.
+    """
+    if isinstance(task_configs, str):
+        if os.path.isfile(task_configs):
+            with open(task_configs, "r") as f:
+                task_configs = json.load(f)
+        else:
+            try:
+                task_configs = json.loads(task_configs)
+            except json.JSONDecodeError as e:
+                raise ValueError(f"Invalid --task_configs format: {e}") from e
+    elif task_configs is None:
+        task_configs = {}
+
     set_cuda_visible_devices(device)
     device_str, parallelism = get_device_and_parallelism(device)
 
@@ -237,6 +283,10 @@ def eval_task_by_task(
 
     if batch_size is None:
         batch_size = "auto:8"
+
+    # -------------------------------
+    # Load model (support gguf)
+    # -------------------------------
     is_gguf_file = False
     if not isinstance(model, str):
         parallelism = False
@@ -265,6 +315,19 @@ def eval_task_by_task(
         )
         model.eval()
         parallelism = False
+
+    # -------------------------------
+    # Build LM-eval model wrapper
+    # -------------------------------
+    if disable_thinking:  ## align with fp-quant
+        from functools import partial
+
+        tokenizer.apply_chat_template = partial(tokenizer.apply_chat_template, enable_thinking=False)
+    # forward an explicit max_length to the LM-eval wrapper if provided
+    init_kwargs = {}
+    if max_length is not None:
+        init_kwargs["max_length"] = max_length
+
     if mllm:
         if batch_size is None or batch_size == "auto":
             logger.warning("hf-multimodal models does not support auto currently, reset eval_bs to 16")
@@ -278,6 +341,7 @@ def eval_task_by_task(
             parallelize=parallelism,
             trust_remote_code=trust_remote_code,
             dtype=eval_model_dtype,
+            **init_kwargs,
         )
     else:
         hflm = HFLM(
@@ -289,6 +353,7 @@ def eval_task_by_task(
             parallelize=parallelism,
             trust_remote_code=trust_remote_code,
             dtype=eval_model_dtype,
+            **init_kwargs,
        )
 
     if isinstance(tasks, str):
@@ -302,10 +367,28 @@ def eval_task_by_task(
 
     st = time.time()
     for task in tasks:
+        task_cfg = task_configs.get(task, {})
+        num_fewshot = task_cfg.get("num_fewshot")
+        apply_chat_template = task_cfg.get("apply_chat_template", False)
+        batch_size = task_cfg.get("batch_size", batch_size)
+        fewshot_as_multiturn = task_cfg.get("fewshot_as_multiturn", False)
+        logger.info(f"=== Running task: {task} ===")
+        logger.info(
+            f"Task config: fewshot={num_fewshot}, apply_chat_template={apply_chat_template}, "
+            f"fewshot_as_multiturn={fewshot_as_multiturn}, batch_size={batch_size}"
+        )
         while retry_times:
             try:
                 res = lm_simple_evaluate(
-                    model=hflm, model_args=None, device=device_str, tasks=task, batch_size=batch_size, limit=limit
+                    model=hflm,
+                    model_args=None,
+                    device=device_str,
+                    tasks=task,
+                    batch_size=batch_size,
+                    limit=limit,
+                    num_fewshot=num_fewshot,
+                    apply_chat_template=apply_chat_template,
+                    fewshot_as_multiturn=fewshot_as_multiturn,
                 )
                 break
             except Exception as e:
@@ -317,7 +400,15 @@ def eval_task_by_task(
                         hflm.batch_sizes[k] = max(v // 2, 1)
                     logger.warning(f"Out of memory, reset batch_size to {hflm.batch_sizes} and re-try.")
                     res = lm_simple_evaluate(
-                        model=hflm, model_args=None, device=device_str, tasks=task, batch_size=1, limit=limit
+                        model=hflm,
+                        model_args=None,
+                        device=device_str,
+                        tasks=task,
+                        batch_size=1,
+                        limit=limit,
+                        num_fewshot=num_fewshot,
+                        apply_chat_template=apply_chat_template,
+                        fewshot_as_multiturn=fewshot_as_multiturn,
                    )
                     hflm.batch_sizes = ori_batch_sizes
             except Exception as e:
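
For reference, a minimal sketch of how the per-task overrides added above could be exercised by calling eval_task_by_task directly. The checkpoint path, tokenizer, and task list are placeholders, not part of this patch:

# Illustrative usage only: exercises the task_configs / disable_thinking / max_length
# parameters introduced in this patch. Model path and task names are placeholders.
from transformers import AutoTokenizer

from auto_round.eval.eval_cli import eval_task_by_task

tokenizer = AutoTokenizer.from_pretrained("./quantized_model")  # placeholder checkpoint
eval_task_by_task(
    "./quantized_model",  # quantized model folder (placeholder)
    device="0",
    tasks="gsm8k_llama,hellaswag",
    tokenizer=tokenizer,
    batch_size=8,
    task_configs={
        "gsm8k_llama": {"apply_chat_template": True, "fewshot_as_multiturn": True},
        "hellaswag": {"num_fewshot": 10},
    },
    disable_thinking=True,  # wraps tokenizer.apply_chat_template with enable_thinking=False
    max_length=4096,  # forwarded to the HFLM wrapper via init_kwargs
)

The same settings are reachable from the CLI through --task_configs (inline JSON or a path to a JSON file such as task_configs.json), together with --disable_thinking and --max_length.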