Commit 1f0365d

[None][infra] Add LongBenchV1 to trtllm-eval. (#10265)

Signed-off-by: Bo Li <[email protected]>

1 parent 6732c76

File tree

9 files changed: +259 −7 lines

requirements-dev.txt

Lines changed: 2 additions & 1 deletion

@@ -28,11 +28,12 @@ jieba==0.42.1
 rouge==1.0.1
 pytest-rerunfailures
 ruff==0.9.4
-lm_eval[api]==0.4.8
+lm_eval[api]==0.4.9.2
 docstring_parser
 genai-perf==0.0.13
 opentelemetry-sdk>=1.26.0
 opentelemetry-api>=1.26.0
 opentelemetry-exporter-otlp>=1.26.0
 opentelemetry-semantic-conventions-ai>=0.4.1
+fuzzywuzzy==0.18.0
 aiperf==0.3.0

tensorrt_llm/commands/eval.py

Lines changed: 3 additions & 1 deletion

@@ -21,7 +21,8 @@
 from .. import LLM as PyTorchLLM
 from .._tensorrt_engine import LLM
 from ..evaluate import (GSM8K, MMLU, MMMU, CnnDailymail, GPQADiamond,
-                        GPQAExtended, GPQAMain, JsonModeEval, LongBenchV2)
+                        GPQAExtended, GPQAMain, JsonModeEval, LongBenchV1,
+                        LongBenchV2)
 from ..llmapi import BuildConfig, KvCacheConfig
 from ..llmapi.llm_utils import update_llm_args_with_extra_options
 from ..logger import logger, severity_map
@@ -184,6 +185,7 @@ def main(ctx, model: str, tokenizer: Optional[str],
 main.add_command(GPQAExtended.command)
 main.add_command(JsonModeEval.command)
 main.add_command(MMMU.command)
+main.add_command(LongBenchV1.command)
 main.add_command(LongBenchV2.command)

 if __name__ == "__main__":

tensorrt_llm/evaluate/__init__.py

Lines changed: 3 additions & 2 deletions

@@ -15,11 +15,12 @@

 from .cnn_dailymail import CnnDailymail
 from .json_mode_eval import JsonModeEval
-from .lm_eval import GSM8K, MMMU, GPQADiamond, GPQAExtended, GPQAMain
+from .lm_eval import (GSM8K, MMMU, GPQADiamond, GPQAExtended, GPQAMain,
+                      LongBenchV1)
 from .longbench_v2 import LongBenchV2
 from .mmlu import MMLU

 __all__ = [
     "CnnDailymail", "MMLU", "GSM8K", "GPQADiamond", "GPQAMain", "GPQAExtended",
-    "JsonModeEval", "MMMU", "LongBenchV2"
+    "JsonModeEval", "MMMU", "LongBenchV1", "LongBenchV2"
 ]
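With LongBenchV1 re-exported from tensorrt_llm.evaluate, the evaluator can also be driven from the Python LLM API rather than through the trtllm-eval CLI. A minimal sketch, assuming a placeholder model path and default LLM constructor arguments (neither comes from this commit):

from tensorrt_llm import LLM
from tensorrt_llm.evaluate import LongBenchV1

# Hypothetical checkpoint path; any LLM-API-compatible model directory works.
llm = LLM(model="/path/to/model")

# Constructor kwargs mirror the longbench_v1 CLI options added in lm_eval.py below.
evaluator = LongBenchV1(num_samples=100, apply_chat_template=True)

# Generation defaults come from lm-eval's per-task gen_kwargs (see the next file).
score = evaluator.evaluate(llm)
print(f"LongBench v1 score: {score:.2f}")

llm.shutdown()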

tensorrt_llm/evaluate/lm_eval.py

Lines changed: 167 additions & 1 deletion

@@ -100,10 +100,23 @@ def _get_sampling_params(self, gen_kwargs: dict) -> SamplingParams:
             "max_gen_toks": "max_tokens",
             "until": "stop",
         }
+        # IMPORTANT:
+        # lm-evaluation-harness controls generation primarily via per-task gen_kwargs.
+        # For example, the `local-completions` model wrapper uses:
+        #   max_tokens  <- gen_kwargs["max_tokens"] or gen_kwargs["max_gen_toks"] or _max_gen_toks
+        #   temperature <- gen_kwargs.get("temperature", 0)
+        #   stop        <- gen_kwargs.get("until", ...)
+        # See: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/models/openai_completions.py
+
         if self.sampling_params is None:
-            sampling_params = SamplingParams()
+            sampling_params = SamplingParams(
+                max_tokens=gen_kwargs.get("max_gen_toks", 256),
+                temperature=gen_kwargs.get("temperature", 0),
+                stop=gen_kwargs.get("until", None),
+            )
         else:
             sampling_params = copy.deepcopy(self.sampling_params)
+
         for lm_eval_key, trtllm_key in params_mapping.items():
             value = gen_kwargs.pop(lm_eval_key, None)
             if value is not None:
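The hunk above seeds SamplingParams from lm-eval's per-task gen_kwargs instead of library defaults. A standalone sketch of how that fallback resolves for a representative set of gen_kwargs (the values are illustrative, not taken from a real LongBench task config):

from typing import Any, Dict

def resolve_generation_defaults(gen_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    # Mirrors the fallback above: max_gen_toks caps max_tokens, temperature
    # defaults to 0 (greedy), and "until" strings become stop sequences.
    return {
        "max_tokens": gen_kwargs.get("max_gen_toks", 256),
        "temperature": gen_kwargs.get("temperature", 0),
        "stop": gen_kwargs.get("until", None),
    }

# Illustrative per-task gen_kwargs, similar in shape to what lm-eval passes.
example = {"max_gen_toks": 512, "temperature": 0.0, "until": ["\n\n"]}
print(resolve_generation_defaults(example))
# {'max_tokens': 512, 'temperature': 0.0, 'stop': ['\n\n']}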

@@ -714,3 +727,156 @@ def command(ctx, **kwargs) -> None:
         kwargs[
             "stop"] = "<|endoftext|>"  # NOTE: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/mmmu/_template_yaml#L10
         MMMU.command_harness(ctx, **kwargs)
+
+
+class LongBenchV1(LmEvalEvaluator):
+    """
+    LongBench v1 evaluation via lm-evaluation-harness.
+
+    Notes:
+    - In lm-eval, `longbench` is typically a *group task* that expands into many
+      subtasks. The base `LmEvalEvaluator.evaluate()` assumes a single task
+      key exists in `results["results"][task_name]`, so we override evaluation
+      to aggregate over subtasks.
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__("longbench", **kwargs)
+
+    @staticmethod
+    def _flatten_task_dict(task_dict: dict) -> List[str]:
+        names: List[str] = []
+        for k, v in task_dict.items():
+            if isinstance(v, dict):
+                names.extend(LongBenchV1._flatten_task_dict(v))
+            else:
+                names.append(k)
+        return names
+
+    @staticmethod
+    def _get_group_score(metrics: Dict[str, Any],
+                         *,
+                         preferred_filter: str = "none") -> Optional[float]:
+        """
+        lm-eval stores group metrics as "<metric>,<filter>" (e.g., "score,none").
+        Prefer "score,none" (matches printed table), otherwise accept any
+        "score,<filter>" key.
+        """
+        if not isinstance(metrics, dict):
+            return None
+
+        preferred_key = f"score,{preferred_filter}"
+        v = metrics.get(preferred_key, None)
+        if isinstance(v, (int, float)):
+            return float(v)
+
+        return None
+
+    def evaluate(self,
+                 llm: Union[LLM, PyTorchLLM],
+                 sampling_params: Optional[SamplingParams] = None,
+                 streaming: bool = False) -> float:
+        import lm_eval
+
+        lm_cls = MultimodalLmEvalWrapper if self.MULTIMODAL else LmEvalWrapper
+        results = lm_eval.evaluate(
+            lm=lm_cls(llm,
+                      sampling_params=sampling_params,
+                      streaming=streaming,
+                      chat_template_kwargs=self.chat_template_kwargs),
+            task_dict=self.task_dict,
+            limit=self.num_samples,
+            apply_chat_template=self.apply_chat_template,
+            fewshot_as_multiturn=self.fewshot_as_multiturn,
+            system_instruction=self.system_prompt)
+
+        logger.info(
+            f"lm-eval {self.task_name} results:\n{lm_eval.utils.make_table(results)}"
+        )
+
+        # LongBench is a group task in lm-eval. lm-eval already computes subgroup
+        # "score" values (e.g., `longbench_fewshot`, `longbench_single`, ...).
+        # To keep this implementation simple and aligned with the printed table,
+        # we compute the final LongBench score as the unweighted mean of subgroup
+        # scores.
+        group_results: Dict[str, Dict[str, Any]] = results.get("groups", {})
+        subgroup_names = results.get("group_subtasks",
+                                     {}).get(self.task_name, [])
+        if not subgroup_names:
+            raise KeyError(
+                f"lm-eval did not provide subgroup list for group '{self.task_name}'. "
+                "Expected `results['group_subtasks'][task_name]` to exist.")
+
+        subgroup_scores: List[float] = []
+        missing: List[str] = []
+        for name in subgroup_names:
+            m = group_results.get(name, None)
+            score = self._get_group_score(m)
+            if score is None:
+                missing.append(name)
+            else:
+                subgroup_scores.append(score)
+
+        if not subgroup_scores:
+            raise KeyError(
+                f"lm-eval did not provide subgroup 'score' metrics for '{self.task_name}'. "
+                f"Missing subgroups: {missing[:10]}")
+
+        result_acc = float(np.mean(subgroup_scores)) * 100
+        logger.info(
+            f"lm-eval {self.task_name} average 'score' across {len(subgroup_scores)} subgroups: {result_acc:.2f}"
+        )
+        return result_acc
+
+    @click.command("longbench_v1")
+    @click.option(
+        "--dataset_path",
+        type=str,
+        default=None,
+        help=
+        "The path to LongBench dataset. If unspecified, the dataset is downloaded from HF hub."
+    )
+    @click.option(
+        "--num_samples",
+        type=int,
+        default=None,
+        help="Number of samples to run the evaluation; None means full dataset."
+    )
+    @click.option("--random_seed",
+                  type=int,
+                  default=0,
+                  help="Random seed for dataset processing.")
+    @click.option("--apply_chat_template",
+                  type=click.BOOL,
+                  default=True,
+                  show_default=True,
+                  help="Whether to apply chat template.")
+    @click.option(
+        "--chat_template_kwargs",
+        type=str,
+        default=None,
+        callback=lambda ctx, param, value: json.loads(value) if value else None,
+        help=
+        'Chat template kwargs as JSON string, e.g., \'{"thinking_budget": 0}\'')
+    @click.option("--system_prompt",
+                  type=str,
+                  default=None,
+                  help="System prompt.")
+    @click.pass_context
+    @staticmethod
+    def command(ctx, **kwargs) -> None:
+        llm: Union[LLM, PyTorchLLM] = ctx.obj
+
+        evaluator = LongBenchV1(
+            dataset_path=kwargs.pop("dataset_path", None),
+            num_samples=kwargs.pop("num_samples", None),
+            random_seed=kwargs.pop("random_seed", 0),
+            apply_chat_template=kwargs.pop("apply_chat_template", True),
+            system_prompt=kwargs.pop("system_prompt", None),
+            chat_template_kwargs=kwargs.pop("chat_template_kwargs", None))
+
+        # Let lm-eval task configs control sampling via gen_kwargs.
+        sampling_params = None
+
+        evaluator.evaluate(llm, sampling_params)
+        llm.shutdown()
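Because longbench is a group task, the final score above is the unweighted mean of the per-subgroup "score,none" metrics that lm-eval reports, rescaled to 0-100. A toy sketch of that aggregation over a hand-made results dict (subgroup names and values are invented for illustration):

import numpy as np

# Shape mimics lm_eval.evaluate() output for a group task; the numbers are made up.
results = {
    "group_subtasks": {"longbench": ["longbench_single", "longbench_multi"]},
    "groups": {
        "longbench_single": {"score,none": 0.46},
        "longbench_multi": {"score,none": 0.41},
    },
}

subgroups = results["group_subtasks"]["longbench"]
scores = [results["groups"][name]["score,none"] for name in subgroups]
print(float(np.mean(scores)) * 100)  # 43.5, reported on a 0-100 scale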

tests/integration/defs/accuracy/accuracy_core.py

Lines changed: 28 additions & 0 deletions

@@ -456,6 +456,34 @@ class LongBenchV2(AccuracyTask):
     )


+class LongBenchV1(AccuracyTask):
+    DATASET = "longbench_v1"
+    # Keep the dataset local like other accuracy tasks (avoid HF hub traffic).
+    # Expected to be populated in CI image / test environment.
+    DATASET_DIR = f"{llm_models_root()}/datasets/Xnhyacinth/LongBench"
+
+    # NOTE: LongBench v1 is driven by lm-evaluation-harness task configs.
+    # We intentionally do not pin dataset_path here (it can be resolved by lm-eval
+    # via HF Hub or local cache).
+    ALPHA = 0.05
+    BETA = 0.2
+    SIGMA = 50.0
+
+    # Full sample
+    NUM_SAMPLES = 4750
+
+    # These are used by AccuracyTask to construct SamplingParams defaults.
+    # LongBench v1 tasks provide per-task gen_kwargs, so these are mainly a safe fallback.
+    MAX_BATCH_SIZE = 256
+    MAX_INPUT_LEN = 128000
+    MAX_OUTPUT_LEN = 1024
+
+    EVALUATOR_CLS = tensorrt_llm.evaluate.LongBenchV1
+    EVALUATOR_KWARGS = dict(dataset_path=DATASET_DIR,
+                            random_seed=0,
+                            apply_chat_template=True)
+
+
 class CliFlowAccuracyTestHarness:
     # Model
     MODEL_NAME = None
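ALPHA, BETA, SIGMA, and NUM_SAMPLES feed the accuracy harness's hypothesis test when a measured score is compared against the YAML reference. The real computation lives in accuracy_core.py; the following is only a generic one-sided Gauss-test sketch of how constants like these typically turn a reference score into a minimum passing threshold (an assumed formula for illustration, not necessarily the project's, using scipy just for the normal quantile):

from math import sqrt

from scipy.stats import norm

def passing_threshold(reference: float,
                      sigma: float = 50.0,
                      num_samples: int = 4750,
                      alpha: float = 0.05,
                      beta: float = 0.2) -> float:
    # Let the measured mean fall below the reference by a margin controlled by
    # the false-alarm rate (alpha), the miss rate (beta), and the per-sample
    # score spread (sigma), shrinking as the sample count grows.
    margin = (norm.ppf(1 - alpha) + norm.ppf(1 - beta)) * sigma / sqrt(num_samples)
    return reference - margin

print(passing_threshold(47.22))  # roughly 45.4 with the constants above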

Lines changed: 8 additions & 0 deletions

@@ -0,0 +1,8 @@
+Qwen3/Qwen3-30B-A3B-Instruct-2507:
+  # Skip Softmax Attention ref accuracy
+  - extra_acc_spec: "target_sparsity=0.0"
+    accuracy: 47.22
+  - extra_acc_spec: "target_sparsity=0.5"
+    accuracy: 47.22
+  - extra_acc_spec: "target_sparsity=0.9"
+    accuracy: 45.90

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 42 additions & 2 deletions

@@ -55,7 +55,7 @@ def patched_start_mpi_pool(self):
                                  EagleDecodingConfig, KvCacheConfig, MoeConfig,
                                  MTPDecodingConfig, NGramDecodingConfig,
                                  RocketSparseAttentionConfig, SamplingParams,
-                                 TorchCompileConfig)
+                                 SkipSoftmaxAttentionConfig, TorchCompileConfig)
 from tensorrt_llm.quantization import QuantAlgo

 from ..conftest import (get_device_count, get_device_memory, llm_models_root,
@@ -64,7 +64,7 @@ def patched_start_mpi_pool(self):
                         skip_pre_hopper, skip_ray)
 from .accuracy_core import (GSM8K, MMLU, CnnDailymail, GPQADiamond,
                             JsonModeEval, LlmapiAccuracyTestHarness,
-                            LongBenchV2)
+                            LongBenchV1, LongBenchV2)


 def _get_default_torch_compile_config(torch_compile):
@@ -3816,6 +3816,46 @@ def test_nvfp4_4gpus(self, tp_size, pp_size, ep_size, attention_dp,
         task.evaluate(llm)


+class TestQwen3_30B_A3B_Instruct_2507(LlmapiAccuracyTestHarness):
+    MODEL_NAME = "Qwen3/Qwen3-30B-A3B-Instruct-2507"
+    MODEL_PATH = f"{llm_models_root()}/{MODEL_NAME}"
+
+    @skip_pre_hopper
+    # @pytest.mark.skip_less_device_memory(140000)  # Only test for H200, B200
+    @pytest.mark.parametrize(
+        "target_sparsity,thr_prefill,thr_decode",
+        [
+            (0.0, 0.0, 0.0),
+            (0.5, 85.97384174442398, 55.48258322852407),
+            (0.9, 1418.142868970396, 863.147841750025),
+        ],
+        ids=[
+            "target_sparsity_0.0", "target_sparsity_0.5", "target_sparsity_0.9"
+        ],
+    )
+    def test_skip_softmax_attention(self, target_sparsity: float,
+                                    thr_prefill: float, thr_decode: float):
+        sparse_attention_config = SkipSoftmaxAttentionConfig(
+            threshold_scale_factor={
+                "prefill": thr_prefill,
+                "decode": thr_decode,
+            })
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.85)
+
+        if get_sm_version() >= 100:
+            pytest.skip("Bug to be fixed on Blackwell")
+
+        with LLM(self.MODEL_PATH,
+                 attn_backend="TRTLLM",
+                 max_batch_size=256,
+                 max_num_tokens=100000,
+                 kv_cache_config=kv_cache_config,
+                 sparse_attention_config=sparse_attention_config) as llm:
+            task = LongBenchV1(self.MODEL_NAME)
+            task.evaluate(llm,
+                          extra_acc_spec=f"target_sparsity={target_sparsity}")
+
+
 class TestPhi4MiniInstruct(LlmapiAccuracyTestHarness):
     MODEL_NAME = "microsoft/Phi-4-mini-instruct"
     MODEL_PATH = f"{llm_models_root()}/Phi-4-mini-instruct"
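The extra_acc_spec strings passed to task.evaluate() select rows in the accuracy reference YAML added earlier for Qwen3/Qwen3-30B-A3B-Instruct-2507. A hypothetical lookup helper, not the harness's actual loader and with a placeholder file name, showing how those entries pair up:

import yaml

def lookup_reference_accuracy(path: str, model_name: str,
                              extra_acc_spec: str) -> float:
    # The reference YAML maps a model name to a list of
    # {extra_acc_spec, accuracy} entries.
    with open(path) as f:
        references = yaml.safe_load(f)
    for entry in references[model_name]:
        if entry.get("extra_acc_spec") == extra_acc_spec:
            return entry["accuracy"]
    raise KeyError(f"No reference for {model_name} / {extra_acc_spec}")

# e.g. lookup_reference_accuracy("longbench_v1.yaml",  # placeholder file name
#                                "Qwen3/Qwen3-30B-A3B-Instruct-2507",
#                                "target_sparsity=0.9")  # -> 45.9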

tests/integration/test_lists/test-db/l0_b200.yml

Lines changed: 3 additions & 0 deletions

@@ -55,6 +55,9 @@ l0_b200:
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[mxfp8-latency-TRTLLM]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[mxfp8-latency-CUTLASS]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a16_mxfp4[latency-TRTLLM]
+  - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.0]
+  - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.5]
+  - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.9]
   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp1-cutlass]
   - disaggregated/test_workers.py::test_workers_kv_cache_aware_router_eviction[TinyLlama-1.1B-Chat-v1.0] # nvbugs 5300551
   - test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-NVFP4-nvfp4-quantized/Meta-Llama-3.1-8B]

tests/integration/test_lists/test-db/l0_h100.yml

Lines changed: 3 additions & 0 deletions

@@ -76,6 +76,9 @@ l0_h100:
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8[latency-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8[latency-torch_compile=True]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_dummy_load_format
+  - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.0] TIMEOUT (90)
+  - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.5] TIMEOUT (90)
+  - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.9] TIMEOUT (90)
   - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_eagle3[enable_chunked_prefill=False-eagle3_one_model=False]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_eagle3[enable_chunked_prefill=True-eagle3_one_model=True]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_eagle3[enable_chunked_prefill=False-eagle3_one_model=True]
