diff --git a/requirements-dev.txt b/requirements-dev.txt
index 5a91580ee74..46525d5fe74 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -28,11 +28,12 @@ jieba==0.42.1
 rouge==1.0.1
 pytest-rerunfailures
 ruff==0.9.4
-lm_eval[api]==0.4.8
+lm_eval[api]==0.4.9.2
 docstring_parser
 genai-perf==0.0.13
 opentelemetry-sdk>=1.26.0
 opentelemetry-api>=1.26.0
 opentelemetry-exporter-otlp>=1.26.0
 opentelemetry-semantic-conventions-ai>=0.4.1
+fuzzywuzzy==0.18.0
 aiperf==0.3.0
diff --git a/tensorrt_llm/commands/eval.py b/tensorrt_llm/commands/eval.py
index 76c6c63eb81..98cea6e7cd6 100644
--- a/tensorrt_llm/commands/eval.py
+++ b/tensorrt_llm/commands/eval.py
@@ -21,7 +21,8 @@
 from .. import LLM as PyTorchLLM
 from .._tensorrt_engine import LLM
 from ..evaluate import (GSM8K, MMLU, MMMU, CnnDailymail, GPQADiamond,
-                        GPQAExtended, GPQAMain, JsonModeEval, LongBenchV2)
+                        GPQAExtended, GPQAMain, JsonModeEval, LongBenchV1,
+                        LongBenchV2)
 from ..llmapi import BuildConfig, KvCacheConfig
 from ..llmapi.llm_utils import update_llm_args_with_extra_options
 from ..logger import logger, severity_map
@@ -181,6 +182,7 @@ def main(ctx, model: str, tokenizer: Optional[str],
 main.add_command(GPQAExtended.command)
 main.add_command(JsonModeEval.command)
 main.add_command(MMMU.command)
+main.add_command(LongBenchV1.command)
 main.add_command(LongBenchV2.command)
 
 if __name__ == "__main__":
diff --git a/tensorrt_llm/evaluate/__init__.py b/tensorrt_llm/evaluate/__init__.py
index b30a99a4aeb..364eb44f608 100755
--- a/tensorrt_llm/evaluate/__init__.py
+++ b/tensorrt_llm/evaluate/__init__.py
@@ -15,11 +15,12 @@
 from .cnn_dailymail import CnnDailymail
 from .json_mode_eval import JsonModeEval
-from .lm_eval import GSM8K, MMMU, GPQADiamond, GPQAExtended, GPQAMain
+from .lm_eval import (GSM8K, MMMU, GPQADiamond, GPQAExtended, GPQAMain,
+                      LongBenchV1)
 from .longbench_v2 import LongBenchV2
 from .mmlu import MMLU
 
 __all__ = [
     "CnnDailymail", "MMLU", "GSM8K", "GPQADiamond", "GPQAMain", "GPQAExtended",
-    "JsonModeEval", "MMMU", "LongBenchV2"
+    "JsonModeEval", "MMMU", "LongBenchV1", "LongBenchV2"
 ]
diff --git a/tensorrt_llm/evaluate/lm_eval.py b/tensorrt_llm/evaluate/lm_eval.py
index b1e1671bc4c..b228498b56c 100644
--- a/tensorrt_llm/evaluate/lm_eval.py
+++ b/tensorrt_llm/evaluate/lm_eval.py
@@ -100,10 +100,23 @@ def _get_sampling_params(self, gen_kwargs: dict) -> SamplingParams:
             "max_gen_toks": "max_tokens",
             "until": "stop",
         }
+        # IMPORTANT:
+        # lm-evaluation-harness controls generation primarily via per-task gen_kwargs.
+        # For example, the `local-completions` model wrapper uses:
+        #   max_tokens  <- gen_kwargs["max_tokens"] or gen_kwargs["max_gen_toks"] or _max_gen_toks
+        #   temperature <- gen_kwargs.get("temperature", 0)
+        #   stop        <- gen_kwargs.get("until", ...)
+        # See: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/models/openai_completions.py
+
         if self.sampling_params is None:
-            sampling_params = SamplingParams()
+            sampling_params = SamplingParams(
+                max_tokens=gen_kwargs.get("max_gen_toks", 256),
+                temperature=gen_kwargs.get("temperature", 0),
+                stop=gen_kwargs.get("until", None),
+            )
         else:
             sampling_params = copy.deepcopy(self.sampling_params)
+
         for lm_eval_key, trtllm_key in params_mapping.items():
             value = gen_kwargs.pop(lm_eval_key, None)
             if value is not None:
@@ -714,3 +727,156 @@ def command(ctx, **kwargs) -> None:
         kwargs[
             "stop"] = "<|endoftext|>"  # NOTE: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/mmmu/_template_yaml#L10
         MMMU.command_harness(ctx, **kwargs)
+
+
+class LongBenchV1(LmEvalEvaluator):
+    """
+    LongBench v1 evaluation via lm-evaluation-harness.
+
+    Notes:
+    - In lm-eval, `longbench` is typically a *group task* that expands into many
+      subtasks. The base `LmEvalEvaluator.evaluate()` assumes a single task
+      key exists in `results["results"][task_name]`, so we override evaluation
+      to aggregate over subtasks.
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__("longbench", **kwargs)
+
+    @staticmethod
+    def _flatten_task_dict(task_dict: dict) -> List[str]:
+        names: List[str] = []
+        for k, v in task_dict.items():
+            if isinstance(v, dict):
+                names.extend(LongBenchV1._flatten_task_dict(v))
+            else:
+                names.append(k)
+        return names
+
+    @staticmethod
+    def _get_group_score(metrics: Dict[str, Any],
+                         *,
+                         preferred_filter: str = "none") -> Optional[float]:
+        """
+        lm-eval stores group metrics as "<metric>,<filter>" (e.g., "score,none").
+        Read the "score,<preferred_filter>" key (by default "score,none", which
+        matches the printed table); return None when it is absent.
+        """
+        if not isinstance(metrics, dict):
+            return None
+
+        preferred_key = f"score,{preferred_filter}"
+        v = metrics.get(preferred_key, None)
+        if isinstance(v, (int, float)):
+            return float(v)
+
+        return None
+
+    def evaluate(self,
+                 llm: Union[LLM, PyTorchLLM],
+                 sampling_params: Optional[SamplingParams] = None,
+                 streaming: bool = False) -> float:
+        import lm_eval
+
+        lm_cls = MultimodalLmEvalWrapper if self.MULTIMODAL else LmEvalWrapper
+        results = lm_eval.evaluate(
+            lm=lm_cls(llm,
+                      sampling_params=sampling_params,
+                      streaming=streaming,
+                      chat_template_kwargs=self.chat_template_kwargs),
+            task_dict=self.task_dict,
+            limit=self.num_samples,
+            apply_chat_template=self.apply_chat_template,
+            fewshot_as_multiturn=self.fewshot_as_multiturn,
+            system_instruction=self.system_prompt)
+
+        logger.info(
+            f"lm-eval {self.task_name} results:\n{lm_eval.utils.make_table(results)}"
+        )
+
+        # LongBench is a group task in lm-eval. lm-eval already computes subgroup
+        # "score" values (e.g., `longbench_fewshot`, `longbench_single`, ...).
+        # To keep this implementation simple and aligned with the printed table,
+        # we compute the final LongBench score as the unweighted mean of subgroup
+        # scores.
+        group_results: Dict[str, Dict[str, Any]] = results.get("groups", {})
+        subgroup_names = results.get("group_subtasks",
+                                     {}).get(self.task_name, [])
+        if not subgroup_names:
+            raise KeyError(
+                f"lm-eval did not provide subgroup list for group '{self.task_name}'. "
" + "Expected `results['group_subtasks'][task_name]` to exist.") + + subgroup_scores: List[float] = [] + missing: List[str] = [] + for name in subgroup_names: + m = group_results.get(name, None) + score = self._get_group_score(m) + if score is None: + missing.append(name) + else: + subgroup_scores.append(score) + + if not subgroup_scores: + raise KeyError( + f"lm-eval did not provide subgroup 'score' metrics for '{self.task_name}'. " + f"Missing subgroups: {missing[:10]}") + + result_acc = float(np.mean(subgroup_scores)) * 100 + logger.info( + f"lm-eval {self.task_name} average 'score' across {len(subgroup_scores)} subgroups: {result_acc:.2f}" + ) + return result_acc + + @click.command("longbench_v1") + @click.option( + "--dataset_path", + type=str, + default=None, + help= + "The path to LongBench dataset. If unspecified, the dataset is downloaded from HF hub." + ) + @click.option( + "--num_samples", + type=int, + default=None, + help="Number of samples to run the evaluation; None means full dataset." + ) + @click.option("--random_seed", + type=int, + default=0, + help="Random seed for dataset processing.") + @click.option("--apply_chat_template", + type=click.BOOL, + default=True, + show_default=True, + help="Whether to apply chat template.") + @click.option( + "--chat_template_kwargs", + type=str, + default=None, + callback=lambda ctx, param, value: json.loads(value) if value else None, + help= + 'Chat template kwargs as JSON string, e.g., \'{"thinking_budget": 0}\'') + @click.option("--system_prompt", + type=str, + default=None, + help="System prompt.") + @click.pass_context + @staticmethod + def command(ctx, **kwargs) -> None: + llm: Union[LLM, PyTorchLLM] = ctx.obj + + evaluator = LongBenchV1( + dataset_path=kwargs.pop("dataset_path", None), + num_samples=kwargs.pop("num_samples", None), + random_seed=kwargs.pop("random_seed", 0), + apply_chat_template=kwargs.pop("apply_chat_template", True), + system_prompt=kwargs.pop("system_prompt", None), + chat_template_kwargs=kwargs.pop("chat_template_kwargs", None)) + + # Let lm-eval task configs control sampling via gen_kwargs. + sampling_params = None + + evaluator.evaluate(llm, sampling_params) + llm.shutdown() diff --git a/tests/integration/defs/accuracy/accuracy_core.py b/tests/integration/defs/accuracy/accuracy_core.py index 0c03a8422a8..f96ac7d6184 100644 --- a/tests/integration/defs/accuracy/accuracy_core.py +++ b/tests/integration/defs/accuracy/accuracy_core.py @@ -456,6 +456,34 @@ class LongBenchV2(AccuracyTask): ) +class LongBenchV1(AccuracyTask): + DATASET = "longbench_v1" + # Keep the dataset local like other accuracy tasks (avoid HF hub traffic). + # Expected to be populated in CI image / test environment. + DATASET_DIR = f"{llm_models_root()}/datasets/Xnhyacinth/LongBench" + + # NOTE: LongBench v1 is driven by lm-evaluation-harness task configs. + # We intentionally do not pin dataset_path here (it can be resolved by lm-eval + # via HF Hub or local cache). + ALPHA = 0.05 + BETA = 0.2 + SIGMA = 50.0 + + # Full sample + NUM_SAMPLES = 4750 + + # These are used by AccuracyTask to construct SamplingParams defaults. + # LongBench v1 tasks provide per-task gen_kwargs, so these are mainly a safe fallback. 
+    MAX_BATCH_SIZE = 256
+    MAX_INPUT_LEN = 128000
+    MAX_OUTPUT_LEN = 1024
+
+    EVALUATOR_CLS = tensorrt_llm.evaluate.LongBenchV1
+    EVALUATOR_KWARGS = dict(dataset_path=DATASET_DIR,
+                            random_seed=0,
+                            apply_chat_template=True)
+
+
 class CliFlowAccuracyTestHarness:
     # Model
     MODEL_NAME = None
diff --git a/tests/integration/defs/accuracy/references/longbench_v1.yaml b/tests/integration/defs/accuracy/references/longbench_v1.yaml
new file mode 100644
index 00000000000..c638ab92bb8
--- /dev/null
+++ b/tests/integration/defs/accuracy/references/longbench_v1.yaml
@@ -0,0 +1,8 @@
+Qwen3/Qwen3-30B-A3B-Instruct-2507:
+  # Skip Softmax Attention ref accuracy
+  - extra_acc_spec: "target_sparsity=0.0"
+    accuracy: 47.22
+  - extra_acc_spec: "target_sparsity=0.5"
+    accuracy: 47.22
+  - extra_acc_spec: "target_sparsity=0.9"
+    accuracy: 45.90
diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index ab9c1591c17..8d8a39a433d 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -55,7 +55,7 @@ def patched_start_mpi_pool(self):
                                  EagleDecodingConfig, KvCacheConfig, MoeConfig,
                                  MTPDecodingConfig, NGramDecodingConfig,
                                  RocketSparseAttentionConfig, SamplingParams,
-                                 TorchCompileConfig)
+                                 SkipSoftmaxAttentionConfig, TorchCompileConfig)
 from tensorrt_llm.quantization import QuantAlgo
 
 from ..conftest import (get_device_count, get_device_memory, llm_models_root,
@@ -64,7 +64,7 @@ def patched_start_mpi_pool(self):
                         skip_pre_hopper, skip_ray)
 from .accuracy_core import (GSM8K, MMLU, CnnDailymail, GPQADiamond,
                             JsonModeEval, LlmapiAccuracyTestHarness,
-                            LongBenchV2)
+                            LongBenchV1, LongBenchV2)
 
 
 class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
@@ -3866,6 +3866,46 @@ def test_nvfp4_4gpus(self, tp_size, pp_size, ep_size, attention_dp,
         task.evaluate(llm)
 
 
+class TestQwen3_30B_A3B_Instruct_2507(LlmapiAccuracyTestHarness):
+    MODEL_NAME = "Qwen3/Qwen3-30B-A3B-Instruct-2507"
+    MODEL_PATH = f"{llm_models_root()}/{MODEL_NAME}"
+
+    @skip_pre_hopper
+    # @pytest.mark.skip_less_device_memory(140000)  # Only test for H200, B200
+    @pytest.mark.parametrize(
+        "target_sparsity,thr_prefill,thr_decode",
+        [
+            (0.0, 0.0, 0.0),
+            (0.5, 85.97384174442398, 55.48258322852407),
+            (0.9, 1418.142868970396, 863.147841750025),
+        ],
+        ids=[
+            "target_sparsity_0.0", "target_sparsity_0.5", "target_sparsity_0.9"
+        ],
+    )
+    def test_skip_softmax_attention(self, target_sparsity: float,
+                                    thr_prefill: float, thr_decode: float):
+        sparse_attention_config = SkipSoftmaxAttentionConfig(
+            threshold_scale_factor={
+                "prefill": thr_prefill,
+                "decode": thr_decode,
+            })
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.85)
+
+        if get_sm_version() >= 100:
+            pytest.skip("Bug to be fixed on Blackwell")
+
+        with LLM(self.MODEL_PATH,
+                 attn_backend="TRTLLM",
+                 max_batch_size=256,
+                 max_num_tokens=100000,
+                 kv_cache_config=kv_cache_config,
+                 sparse_attention_config=sparse_attention_config) as llm:
+            task = LongBenchV1(self.MODEL_NAME)
+            task.evaluate(llm,
+                          extra_acc_spec=f"target_sparsity={target_sparsity}")
+
+
 class TestPhi4MiniInstruct(LlmapiAccuracyTestHarness):
     MODEL_NAME = "microsoft/Phi-4-mini-instruct"
     MODEL_PATH = f"{llm_models_root()}/Phi-4-mini-instruct"
diff --git a/tests/integration/test_lists/test-db/l0_b200.yml b/tests/integration/test_lists/test-db/l0_b200.yml
index 6ec17f1af91..aa9f0ffc712 100644
--- a/tests/integration/test_lists/test-db/l0_b200.yml
+++ b/tests/integration/test_lists/test-db/l0_b200.yml
@@ -55,6 +55,9 @@ l0_b200:
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[mxfp8-latency-TRTLLM]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[mxfp8-latency-CUTLASS]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a16_mxfp4[latency-TRTLLM]
+  - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.0]
+  - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.5]
+  - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.9]
   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp1-cutlass]
   - disaggregated/test_workers.py::test_workers_kv_cache_aware_router_eviction[TinyLlama-1.1B-Chat-v1.0] # nvbugs 5300551
   - test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-NVFP4-nvfp4-quantized/Meta-Llama-3.1-8B]
diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml
index 26767235ace..7a691ee8fc7 100644
--- a/tests/integration/test_lists/test-db/l0_h100.yml
+++ b/tests/integration/test_lists/test-db/l0_h100.yml
@@ -76,6 +76,9 @@ l0_h100:
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8[latency-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8[latency-torch_compile=True]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_dummy_load_format
+  - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.0] TIMEOUT (90)
+  - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.5] TIMEOUT (90)
+  - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.9] TIMEOUT (90)
   - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_eagle3[enable_chunked_prefill=False-eagle3_one_model=False]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_eagle3[enable_chunked_prefill=True-eagle3_one_model=True]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_eagle3[enable_chunked_prefill=False-eagle3_one_model=True]