3 changes: 2 additions & 1 deletion requirements-dev.txt
@@ -28,11 +28,12 @@ jieba==0.42.1
rouge==1.0.1
pytest-rerunfailures
ruff==0.9.4
lm_eval[api]==0.4.8
lm_eval[api]==0.4.9.2
docstring_parser
genai-perf==0.0.13
opentelemetry-sdk>=1.26.0
opentelemetry-api>=1.26.0
opentelemetry-exporter-otlp>=1.26.0
opentelemetry-semantic-conventions-ai>=0.4.1
fuzzywuzzy==0.18.0
aiperf==0.3.0
4 changes: 3 additions & 1 deletion tensorrt_llm/commands/eval.py
@@ -21,7 +21,8 @@
from .. import LLM as PyTorchLLM
from .._tensorrt_engine import LLM
from ..evaluate import (GSM8K, MMLU, MMMU, CnnDailymail, GPQADiamond,
GPQAExtended, GPQAMain, JsonModeEval, LongBenchV2)
GPQAExtended, GPQAMain, JsonModeEval, LongBenchV1,
LongBenchV2)
from ..llmapi import BuildConfig, KvCacheConfig
from ..llmapi.llm_utils import update_llm_args_with_extra_options
from ..logger import logger, severity_map
@@ -181,6 +182,7 @@ def main(ctx, model: str, tokenizer: Optional[str],
main.add_command(GPQAExtended.command)
main.add_command(JsonModeEval.command)
main.add_command(MMMU.command)
main.add_command(LongBenchV1.command)
main.add_command(LongBenchV2.command)

if __name__ == "__main__":
5 changes: 3 additions & 2 deletions tensorrt_llm/evaluate/__init__.py
@@ -15,11 +15,12 @@

from .cnn_dailymail import CnnDailymail
from .json_mode_eval import JsonModeEval
from .lm_eval import GSM8K, MMMU, GPQADiamond, GPQAExtended, GPQAMain
from .lm_eval import (GSM8K, MMMU, GPQADiamond, GPQAExtended, GPQAMain,
LongBenchV1)
from .longbench_v2 import LongBenchV2
from .mmlu import MMLU

__all__ = [
"CnnDailymail", "MMLU", "GSM8K", "GPQADiamond", "GPQAMain", "GPQAExtended",
"JsonModeEval", "MMMU", "LongBenchV2"
"JsonModeEval", "MMMU", "LongBenchV1", "LongBenchV2"
]
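With the export in place, the evaluator can also be driven straight from Python, bypassing the CLI. A minimal sketch, assuming a local checkpoint path — the model path and the `num_samples` value below are placeholders, not taken from this PR:

```python
from tensorrt_llm import LLM
from tensorrt_llm.evaluate import LongBenchV1

# Placeholder checkpoint directory; substitute a real model path.
llm = LLM(model="/path/to/Qwen3-30B-A3B-Instruct-2507")

# Mirror the CLI defaults added in this PR; num_samples=8 is only a smoke-test value.
evaluator = LongBenchV1(apply_chat_template=True, random_seed=0, num_samples=8)
score = evaluator.evaluate(llm)  # LongBench v1 score, reported as a percentage
print(f"LongBench v1: {score:.2f}")
llm.shutdown()
```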
168 changes: 167 additions & 1 deletion tensorrt_llm/evaluate/lm_eval.py
@@ -100,10 +100,23 @@ def _get_sampling_params(self, gen_kwargs: dict) -> SamplingParams:
"max_gen_toks": "max_tokens",
"until": "stop",
}
# IMPORTANT:
# lm-evaluation-harness controls generation primarily via per-task gen_kwargs.
# For example, the `local-completions` model wrapper uses:
# max_tokens <- gen_kwargs["max_tokens"] or gen_kwargs["max_gen_toks"] or _max_gen_toks
# temperature <- gen_kwargs.get("temperature", 0)
# stop <- gen_kwargs.get("until", ...)
# See: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/models/openai_completions.py

if self.sampling_params is None:
sampling_params = SamplingParams()
sampling_params = SamplingParams(
max_tokens=gen_kwargs.get("max_gen_toks", 256),
temperature=gen_kwargs.get("temperature", 0),
stop=gen_kwargs.get("until", None),
)
else:
sampling_params = copy.deepcopy(self.sampling_params)

for lm_eval_key, trtllm_key in params_mapping.items():
value = gen_kwargs.pop(lm_eval_key, None)
if value is not None:
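To make the gen_kwargs handling above concrete, here is a self-contained sketch of the translation performed in `_get_sampling_params` when no explicit `sampling_params` is passed. A plain dict stands in for `SamplingParams`, and only the two mapping entries visible in this hunk are reproduced; everything else is an illustrative assumption:

```python
# Simplified stand-in for _get_sampling_params: per-task gen_kwargs from
# lm-evaluation-harness are folded into sampling settings, with harness-style
# defaults (max_gen_toks=256, temperature=0, no stop strings) as fallback.
def translate_gen_kwargs(gen_kwargs: dict) -> dict:
    params_mapping = {
        "max_gen_toks": "max_tokens",
        "until": "stop",
    }
    sampling = {
        "max_tokens": gen_kwargs.get("max_gen_toks", 256),
        "temperature": gen_kwargs.get("temperature", 0),
        "stop": gen_kwargs.get("until", None),
    }
    # Explicit per-task values override the defaults above.
    for lm_eval_key, trtllm_key in params_mapping.items():
        value = gen_kwargs.pop(lm_eval_key, None)
        if value is not None:
            sampling[trtllm_key] = value
    return sampling


# A LongBench-style task that caps generation at 64 tokens and stops on newline.
print(translate_gen_kwargs({"max_gen_toks": 64, "until": ["\n"]}))
# -> {'max_tokens': 64, 'temperature': 0, 'stop': ['\n']}
```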
@@ -714,3 +727,156 @@ def command(ctx, **kwargs) -> None:
kwargs[
"stop"] = "<|endoftext|>" # NOTE: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/mmmu/_template_yaml#L10
MMMU.command_harness(ctx, **kwargs)


class LongBenchV1(LmEvalEvaluator):
"""
LongBench v1 evaluation via lm-evaluation-harness.

Notes:
- In lm-eval, `longbench` is typically a *group task* that expands into many
subtasks. The base `LmEvalEvaluator.evaluate()` assumes a single task
key exists in `results["results"][task_name]`, so we override evaluation
to aggregate over subtasks.
"""

def __init__(self, **kwargs):
super().__init__("longbench", **kwargs)

@staticmethod
def _flatten_task_dict(task_dict: dict) -> List[str]:
names: List[str] = []
for k, v in task_dict.items():
if isinstance(v, dict):
names.extend(LongBenchV1._flatten_task_dict(v))
else:
names.append(k)
return names

@staticmethod
def _get_group_score(metrics: Dict[str, Any],
*,
preferred_filter: str = "none") -> Optional[float]:
"""
lm-eval stores group metrics as "<metric>,<filter>" (e.g., "score,none").
Prefer "score,none" (matches printed table), otherwise accept any
"score,<filter>" key.
"""
if not isinstance(metrics, dict):
return None

        preferred_key = f"score,{preferred_filter}"
        v = metrics.get(preferred_key, None)
        if isinstance(v, (int, float)):
            return float(v)

        # Fall back to any other "score,<filter>" entry, as the docstring promises.
        for key, value in metrics.items():
            if key.startswith("score,") and isinstance(value, (int, float)):
                return float(value)

        return None

def evaluate(self,
llm: Union[LLM, PyTorchLLM],
sampling_params: Optional[SamplingParams] = None,
streaming: bool = False) -> float:
import lm_eval

lm_cls = MultimodalLmEvalWrapper if self.MULTIMODAL else LmEvalWrapper
results = lm_eval.evaluate(
lm=lm_cls(llm,
sampling_params=sampling_params,
streaming=streaming,
chat_template_kwargs=self.chat_template_kwargs),
task_dict=self.task_dict,
limit=self.num_samples,
apply_chat_template=self.apply_chat_template,
fewshot_as_multiturn=self.fewshot_as_multiturn,
system_instruction=self.system_prompt)

logger.info(
f"lm-eval {self.task_name} results:\n{lm_eval.utils.make_table(results)}"
)

# LongBench is a group task in lm-eval. lm-eval already computes subgroup
# "score" values (e.g., `longbench_fewshot`, `longbench_single`, ...).
# To keep this implementation simple and aligned with the printed table,
# we compute the final LongBench score as the unweighted mean of subgroup
# scores.
group_results: Dict[str, Dict[str, Any]] = results.get("groups", {})
subgroup_names = results.get("group_subtasks",
{}).get(self.task_name, [])
if not subgroup_names:
raise KeyError(
f"lm-eval did not provide subgroup list for group '{self.task_name}'. "
"Expected `results['group_subtasks'][task_name]` to exist.")

subgroup_scores: List[float] = []
missing: List[str] = []
for name in subgroup_names:
m = group_results.get(name, None)
score = self._get_group_score(m)
if score is None:
missing.append(name)
else:
subgroup_scores.append(score)

if not subgroup_scores:
raise KeyError(
f"lm-eval did not provide subgroup 'score' metrics for '{self.task_name}'. "
f"Missing subgroups: {missing[:10]}")

result_acc = float(np.mean(subgroup_scores)) * 100
logger.info(
f"lm-eval {self.task_name} average 'score' across {len(subgroup_scores)} subgroups: {result_acc:.2f}"
)
return result_acc

@click.command("longbench_v1")
@click.option(
"--dataset_path",
type=str,
default=None,
help=
"The path to LongBench dataset. If unspecified, the dataset is downloaded from HF hub."
)
@click.option(
"--num_samples",
type=int,
default=None,
help="Number of samples to run the evaluation; None means full dataset."
)
@click.option("--random_seed",
type=int,
default=0,
help="Random seed for dataset processing.")
@click.option("--apply_chat_template",
type=click.BOOL,
default=True,
show_default=True,
help="Whether to apply chat template.")
@click.option(
"--chat_template_kwargs",
type=str,
default=None,
callback=lambda ctx, param, value: json.loads(value) if value else None,
help=
'Chat template kwargs as JSON string, e.g., \'{"thinking_budget": 0}\'')
@click.option("--system_prompt",
type=str,
default=None,
help="System prompt.")
@click.pass_context
@staticmethod
def command(ctx, **kwargs) -> None:
llm: Union[LLM, PyTorchLLM] = ctx.obj

evaluator = LongBenchV1(
dataset_path=kwargs.pop("dataset_path", None),
num_samples=kwargs.pop("num_samples", None),
random_seed=kwargs.pop("random_seed", 0),
apply_chat_template=kwargs.pop("apply_chat_template", True),
system_prompt=kwargs.pop("system_prompt", None),
chat_template_kwargs=kwargs.pop("chat_template_kwargs", None))

# Let lm-eval task configs control sampling via gen_kwargs.
sampling_params = None

evaluator.evaluate(llm, sampling_params)
llm.shutdown()
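As a reviewer aid, the subgroup aggregation in `LongBenchV1.evaluate()` boils down to the following toy reproduction. The nested dicts imitate the shape of lm-eval's results object in simplified form — the subgroup names and scores are invented — and the final number is the unweighted mean of the subgroup `score,none` values, scaled to a percentage:

```python
import numpy as np

# Invented, simplified stand-in for the dict returned by lm_eval.evaluate().
results = {
    "group_subtasks": {"longbench": ["longbench_single", "longbench_multi"]},
    "groups": {
        "longbench_single": {"score,none": 0.48},
        "longbench_multi": {"score,none": 0.42},
    },
}

subgroups = results["group_subtasks"]["longbench"]
scores = [results["groups"][name]["score,none"] for name in subgroups]
result_acc = float(np.mean(scores)) * 100  # unweighted mean, as a percentage
print(f"LongBench v1 score: {result_acc:.2f}")  # -> LongBench v1 score: 45.00
```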
28 changes: 28 additions & 0 deletions tests/integration/defs/accuracy/accuracy_core.py
@@ -456,6 +456,34 @@ class LongBenchV2(AccuracyTask):
)


class LongBenchV1(AccuracyTask):
DATASET = "longbench_v1"
# Keep the dataset local like other accuracy tasks (avoid HF hub traffic).
# Expected to be populated in CI image / test environment.
DATASET_DIR = f"{llm_models_root()}/datasets/Xnhyacinth/LongBench"

    # NOTE: LongBench v1 is driven by lm-evaluation-harness task configs;
    # dataset_path (set in EVALUATOR_KWARGS below) points lm-eval at the local
    # copy so CI does not hit the HF Hub.
ALPHA = 0.05
BETA = 0.2
SIGMA = 50.0

    # Full dataset (evaluate all samples)
NUM_SAMPLES = 4750

# These are used by AccuracyTask to construct SamplingParams defaults.
# LongBench v1 tasks provide per-task gen_kwargs, so these are mainly a safe fallback.
MAX_BATCH_SIZE = 256
MAX_INPUT_LEN = 128000
MAX_OUTPUT_LEN = 1024

EVALUATOR_CLS = tensorrt_llm.evaluate.LongBenchV1
EVALUATOR_KWARGS = dict(dataset_path=DATASET_DIR,
random_seed=0,
apply_chat_template=True)


class CliFlowAccuracyTestHarness:
# Model
MODEL_NAME = None
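The `ALPHA` / `BETA` / `SIGMA` / `NUM_SAMPLES` fields feed the accuracy harness's statistical pass/fail check. The exact formula lives in `AccuracyTask` and is not part of this diff, so the snippet below is only one plausible normal-approximation construction — an illustration of how such parameters can turn a reference score into an acceptance threshold, not the repository's actual rule:

```python
from math import sqrt

from scipy.stats import norm

# Values from the LongBenchV1 task above; the reference comes from
# references/longbench_v1.yaml (the target_sparsity=0.0 entry).
alpha, beta, sigma, n = 0.05, 0.2, 50.0, 4750
reference = 47.22

# One-sided bound: tolerate a drop of (z_{1-alpha} + z_{1-beta}) * sigma / sqrt(n)
# below the reference before flagging a regression.
allowed_drop = (norm.ppf(1 - alpha) + norm.ppf(1 - beta)) * sigma / sqrt(n)
threshold = reference - allowed_drop
print(f"minimum passing accuracy ~ {threshold:.2f}")  # roughly 45.4 with these numbers
```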
8 changes: 8 additions & 0 deletions tests/integration/defs/accuracy/references/longbench_v1.yaml
@@ -0,0 +1,8 @@
Qwen3/Qwen3-30B-A3B-Instruct-2507:
# Skip Softmax Attention ref accuracy
- extra_acc_spec: "target_sparsity=0.0"
accuracy: 47.22
- extra_acc_spec: "target_sparsity=0.5"
accuracy: 47.22
- extra_acc_spec: "target_sparsity=0.9"
accuracy: 45.90
44 changes: 42 additions & 2 deletions tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -55,7 +55,7 @@ def patched_start_mpi_pool(self):
EagleDecodingConfig, KvCacheConfig, MoeConfig,
MTPDecodingConfig, NGramDecodingConfig,
RocketSparseAttentionConfig, SamplingParams,
TorchCompileConfig)
SkipSoftmaxAttentionConfig, TorchCompileConfig)
from tensorrt_llm.quantization import QuantAlgo

from ..conftest import (get_device_count, get_device_memory, llm_models_root,
@@ -64,7 +64,7 @@ def patched_start_mpi_pool(self):
skip_pre_hopper, skip_ray)
from .accuracy_core import (GSM8K, MMLU, CnnDailymail, GPQADiamond,
JsonModeEval, LlmapiAccuracyTestHarness,
LongBenchV2)
LongBenchV1, LongBenchV2)


class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
@@ -3866,6 +3866,46 @@ def test_nvfp4_4gpus(self, tp_size, pp_size, ep_size, attention_dp,
task.evaluate(llm)


class TestQwen3_30B_A3B_Instruct_2507(LlmapiAccuracyTestHarness):
MODEL_NAME = "Qwen3/Qwen3-30B-A3B-Instruct-2507"
MODEL_PATH = f"{llm_models_root()}/{MODEL_NAME}"

@skip_pre_hopper
# @pytest.mark.skip_less_device_memory(140000) # Only test for H200, B200
@pytest.mark.parametrize(
"target_sparsity,thr_prefill,thr_decode",
[
(0.0, 0.0, 0.0),
(0.5, 85.97384174442398, 55.48258322852407),
(0.9, 1418.142868970396, 863.147841750025),
],
ids=[
"target_sparsity_0.0", "target_sparsity_0.5", "target_sparsity_0.9"
],
)
def test_skip_softmax_attention(self, target_sparsity: float,
thr_prefill: float, thr_decode: float):
sparse_attention_config = SkipSoftmaxAttentionConfig(
threshold_scale_factor={
"prefill": thr_prefill,
"decode": thr_decode,
})
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.85)

if get_sm_version() >= 100:
pytest.skip("Bug to be fixed on Blackwell")

with LLM(self.MODEL_PATH,
attn_backend="TRTLLM",
max_batch_size=256,
max_num_tokens=100000,
kv_cache_config=kv_cache_config,
sparse_attention_config=sparse_attention_config) as llm:
task = LongBenchV1(self.MODEL_NAME)
task.evaluate(llm,
extra_acc_spec=f"target_sparsity={target_sparsity}")


class TestPhi4MiniInstruct(LlmapiAccuracyTestHarness):
MODEL_NAME = "microsoft/Phi-4-mini-instruct"
MODEL_PATH = f"{llm_models_root()}/Phi-4-mini-instruct"
3 changes: 3 additions & 0 deletions tests/integration/test_lists/test-db/l0_b200.yml
@@ -55,6 +55,9 @@ l0_b200:
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[mxfp8-latency-TRTLLM]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[mxfp8-latency-CUTLASS]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a16_mxfp4[latency-TRTLLM]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.0]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.5]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.9]
- accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp1-cutlass]
- disaggregated/test_workers.py::test_workers_kv_cache_aware_router_eviction[TinyLlama-1.1B-Chat-v1.0] # nvbugs 5300551
- test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-NVFP4-nvfp4-quantized/Meta-Llama-3.1-8B]
3 changes: 3 additions & 0 deletions tests/integration/test_lists/test-db/l0_h100.yml
@@ -76,6 +76,9 @@ l0_h100:
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8[latency-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8[latency-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_dummy_load_format
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.0] TIMEOUT (90)
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.5] TIMEOUT (90)
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.9] TIMEOUT (90)
- accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_eagle3[enable_chunked_prefill=False-eagle3_one_model=False]
- accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_eagle3[enable_chunked_prefill=True-eagle3_one_model=True]
- accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_eagle3[enable_chunked_prefill=False-eagle3_one_model=True]