203 changes: 202 additions & 1 deletion tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -14,11 +14,14 @@
# limitations under the License.
import asyncio
import os
import sys
import time

import pytest
import torch
from datasets import load_dataset
from mpi4py.futures import MPIPoolExecutor

from tensorrt_llm.llmapi.llm import RequestOutput

def patch_mpi_pool_session_for_env(mocker, env_vars: dict):
"""
@@ -3021,6 +3024,204 @@ def test_nvfp4(self, tp_size):
        task.evaluate(llm)
        task = GSM8K(model_name)
        task.evaluate(llm)

    @skip_pre_blackwell
    @pytest.mark.skip_less_device(8)
    @pytest.mark.skip_less_device_memory(120000)
    @pytest.mark.timeout(14400)
    def test_nvfp4_longseq_trtllm_moe(self, mocker):
"""
Long-sequence MoE regression test with PDL enabled.

Tests long-sequence generation (250K tokens) with TRTLLM MoE backend and PDL.
Validates that the model produces valid outputs (non-empty, non-zero tokens)
across multiple batches with both greedy and sampling strategies.
Also includes async streaming with cancellation test.
Finally verifies accuracy with MMLU and GSM8K benchmarks.
More info in https://nvbugspro.nvidia.com/bug/5661741
"""
        patch_mpi_pool_session_for_env(mocker, {"TRTLLM_ENABLE_PDL": "1"})

        model_name = "moonshotai/Kimi-K2-Thinking"
        model_path = f"{llm_models_root()}/Kimi-K2-Thinking-NVFP4"

        kv_cache_config = KvCacheConfig(
            dtype="fp8",
            free_gpu_memory_fraction=0.9,
            enable_block_reuse=True,
            enable_partial_reuse=False,
            event_buffer_max_size=1024,
        )

        target_len = 250000

        with LLM(
                model_path,
                tensor_parallel_size=8,
                moe_expert_parallel_size=4,
                moe_config=MoeConfig(backend="TRTLLM"),
                enable_chunked_prefill=True,
                trust_remote_code=True,
                kv_cache_config=kv_cache_config,
                max_num_tokens=8192,
                max_seq_len=262144,
                max_batch_size=32,
                enable_attention_dp=True,
        ) as llm:
            assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4

            tokenizer = llm.tokenizer
            dataset = load_dataset(
                "Crystalcareai/Code-feedback-sharegpt-renamed",
                split="train[:2000]")

            long_token_list = []
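            # Tile each tokenized message until it reaches target_len (~250K tokens).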
            for row in dataset:
                msg = row["messages"][0]["value"]
                tokens = tokenizer.encode(msg, add_special_tokens=False)
                if not tokens:
                    continue
                repeat = target_len // len(tokens) + 1
                long_tokens = (tokens * repeat)[:target_len]
                long_token_list.append(long_tokens)

            assert len(long_token_list) > 0, (
                "No valid samples found to build long sequences")

            batch_size = 10
            sampling_params_greedy = SamplingParams(max_tokens=8)
            sampling_params_sampling = SamplingParams(
                max_tokens=8, temperature=0.8, top_p=0.95)

            max_duration_sec = 3 * 3600
            max_batches = 200
            start_time = time.time()
            num_samples = len(long_token_list)
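            # Budget: up to 200 batches or 3 hours of wall-clock time, whichever is hit first.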

            for batch_idx in range(max_batches):
                if (time.time() - start_time) >= max_duration_sec:
                    break
                start_idx = (batch_idx * batch_size) % num_samples

                indices = [(start_idx + i) % num_samples
                           for i in range(batch_size)]
                batch_inputs = [long_token_list[i] for i in indices]

                outputs = llm.generate(
                    batch_inputs, sampling_params=sampling_params_greedy)

                for i, output in enumerate(outputs):
                    token_ids = output.outputs[0].token_ids
                    text = output.outputs[0].text

                    assert len(token_ids) > 0, (
                        f"[greedy] Empty output. "
                        f"Batch {batch_idx}, Request {i}, "
                        f"indices={indices[i]}, text_snippet={text[:128]!r}")

                    assert not all(tid == 0 for tid in token_ids), (
                        f"[greedy] All token IDs are 0. "
                        f"Batch {batch_idx}, Request {i}, "
                        f"indices={indices[i]}, text_snippet={text[:128]!r}")

                outputs_sampling = llm.generate(
                    batch_inputs, sampling_params=sampling_params_sampling)

                for i, output in enumerate(outputs_sampling):
                    token_ids = output.outputs[0].token_ids
                    text = output.outputs[0].text

                    assert len(token_ids) > 0, (
                        f"[sampling] Empty output. "
                        f"Batch {batch_idx}, Request {i}, "
                        f"indices={indices[i]}, text_snippet={text[:128]!r}")

                    assert not all(tid == 0 for tid in token_ids), (
                        f"[sampling] All token IDs are 0. "
                        f"Batch {batch_idx}, Request {i}, "
                        f"indices={indices[i]}, text_snippet={text[:128]!r}")

            print("\n[Async Streaming Test] Starting async streaming cancellation test...")

            async_batch_size = 8
            num_async_batches = 5
            cancel_ratio = 0.5
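            # Cancel roughly the first half of the requests in each async batch
            # after a few streamed chunks.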

            async def run_streaming_with_cancellation():
                for async_batch_idx in range(num_async_batches):
                    start_idx = (async_batch_idx * async_batch_size) % num_samples
                    indices = [(start_idx + i) % num_samples
                               for i in range(async_batch_size)]
                    batch_inputs = [long_token_list[i] for i in indices]

                    sampling_params_streaming = SamplingParams(
                        max_tokens=50, temperature=0.8, top_p=0.95)

                    # Generate async streaming results
                    async_results = llm.generate_async(
                        batch_inputs,
                        sampling_params=sampling_params_streaming,
                        streaming=True)

                    for req_idx, async_gen in enumerate(async_results):
                        chunks_received = 0
                        max_chunks_before_cancel = 5
                        should_cancel = (req_idx < async_batch_size * cancel_ratio)

                        try:
                            async for chunk in async_gen:
                                chunks_received += 1

                                # Validate chunk
                                if chunk.outputs:
                                    token_ids = chunk.outputs[0].token_ids
                                    text = chunk.outputs[0].text

                                    assert len(token_ids) > 0, (
                                        f"[async-streaming] Empty chunk. "
                                        f"AsyncBatch {async_batch_idx}, Request {req_idx}, "
                                        f"Chunk {chunks_received}")

                                    assert not all(
                                        tid == 0 for tid in token_ids), (
                                            f"[async-streaming] All tokens are 0. "
                                            f"AsyncBatch {async_batch_idx}, Request {req_idx}, "
                                            f"Chunk {chunks_received}")

                                # Simulate cancellation for some requests
                                if should_cancel and chunks_received >= max_chunks_before_cancel:
                                    print(f" [Cancel] Request {req_idx} after {chunks_received} chunks")
                                    break
                        except Exception as e:
                            # Log but don't fail on cancellation-related errors
                            print(f" [Warning] Request {req_idx} exception: {e}")
                            if not should_cancel:
                                raise

                    print(f" [Async Streaming] Completed batch {async_batch_idx + 1}/{num_async_batches}")

            # Run async streaming test
            asyncio.run(run_streaming_with_cancellation())
            print("[Async Streaming Test] Completed successfully")

            # Accuracy eval
            task = MMLU(model_name)
            task.evaluate(llm)

            task = GSM8K(model_name)
            task.evaluate(llm)


class TestMinitron4BBaseInstruct(LlmapiAccuracyTestHarness):