diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index 1a32e333b5a..e98abffe065 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -14,11 +14,14 @@
 # limitations under the License.
 import os
 import sys
+from tensorrt_llm.llmapi.llm import RequestOutput
+import time
 
 import pytest
 import torch
+from datasets import load_dataset
 from mpi4py.futures import MPIPoolExecutor
-
+import asyncio
 
 def patch_mpi_pool_session_for_env(mocker, env_vars: dict):
     """
@@ -3021,6 +3024,204 @@ def test_nvfp4(self, tp_size):
         task.evaluate(llm)
         task = GSM8K(model_name)
         task.evaluate(llm)
+
+    @skip_pre_blackwell
+    @pytest.mark.skip_less_device(8)
+    @pytest.mark.skip_less_device_memory(120000)
+    @pytest.mark.timeout(14400)
+    def test_nvfp4_longseq_trtllm_moe(self, mocker):
+        """
+        Long-sequence MoE regression test with PDL enabled.
+
+        Runs generation on ~250K-token input sequences with the TRTLLM MoE
+        backend and PDL enabled. Validates that the model produces valid
+        outputs (non-empty, non-zero token IDs) across multiple batches with
+        both greedy and sampling strategies, exercises async streaming with
+        early request cancellation, and finally verifies accuracy with the
+        MMLU and GSM8K benchmarks.
+        More info: https://nvbugspro.nvidia.com/bug/5661741
+        """
+        patch_mpi_pool_session_for_env(mocker, {"TRTLLM_ENABLE_PDL": "1"})
+
+        model_name = "moonshotai/Kimi-K2-Thinking"
+        model_path = f"{llm_models_root()}/Kimi-K2-Thinking-NVFP4"
+
+        kv_cache_config = KvCacheConfig(
+            dtype="fp8",
+            free_gpu_memory_fraction=0.9,
+            enable_block_reuse=True,
+            enable_partial_reuse=False,
+            event_buffer_max_size=1024,
+        )
+
+        target_len = 250000
+
+        with LLM(
+                model_path,
+                tensor_parallel_size=8,
+                moe_expert_parallel_size=4,
+                moe_config=MoeConfig(backend="TRTLLM"),
+                enable_chunked_prefill=True,
+                trust_remote_code=True,
+                kv_cache_config=kv_cache_config,
+                max_num_tokens=8192,
+                max_seq_len=262144,
+                max_batch_size=32,
+                enable_attention_dp=True,
+        ) as llm:
+            assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
+
+            tokenizer = llm.tokenizer
+            dataset = load_dataset(
+                "Crystalcareai/Code-feedback-sharegpt-renamed",
+                split="train[:2000]")
+
+            # Build ~250K-token prompts by repeating each sample until it
+            # reaches target_len.
+            long_token_list = []
+            for row in dataset:
+                msg = row["messages"][0]["value"]
+                tokens = tokenizer.encode(msg, add_special_tokens=False)
+                if not tokens:
+                    continue
+                repeat = target_len // len(tokens) + 1
+                long_tokens = (tokens * repeat)[:target_len]
+                long_token_list.append(long_tokens)
+
+            assert len(long_token_list) > 0, \
+                "No valid samples found to build long sequences"
+
+            batch_size = 10
+            sampling_params_greedy = SamplingParams(max_tokens=8)
+            sampling_params_sampling = SamplingParams(max_tokens=8,
+                                                      temperature=0.8,
+                                                      top_p=0.95)
+
+            max_duration_sec = 3 * 3600
+            max_batches = 200
+            start_time = time.time()
+            num_samples = len(long_token_list)
+
+            for batch_idx in range(max_batches):
+                if (time.time() - start_time) >= max_duration_sec:
+                    break
+                start_idx = (batch_idx * batch_size) % num_samples
+                indices = [(start_idx + i) % num_samples
+                           for i in range(batch_size)]
+                batch_inputs = [long_token_list[i] for i in indices]
+
+                outputs = llm.generate(batch_inputs,
+                                       sampling_params=sampling_params_greedy)
+
+                for i, output in enumerate(outputs):
+                    token_ids = output.outputs[0].token_ids
+                    text = output.outputs[0].text
+
+                    assert len(token_ids) > 0, (
+                        f"[greedy] Empty output. "
+                        f"Batch {batch_idx}, Request {i}, "
+                        f"indices={indices[i]}, text_snippet={text[:128]!r}")
+
+                    assert not all(tid == 0 for tid in token_ids), (
+                        f"[greedy] All token IDs are 0. "
+                        f"Batch {batch_idx}, Request {i}, "
+                        f"indices={indices[i]}, text_snippet={text[:128]!r}")
+
+                outputs_sampling = llm.generate(
+                    batch_inputs, sampling_params=sampling_params_sampling)
+
+                for i, output in enumerate(outputs_sampling):
+                    token_ids = output.outputs[0].token_ids
+                    text = output.outputs[0].text
+
+                    assert len(token_ids) > 0, (
+                        f"[sampling] Empty output. "
+                        f"Batch {batch_idx}, Request {i}, "
+                        f"indices={indices[i]}, text_snippet={text[:128]!r}")
+
+                    assert not all(tid == 0 for tid in token_ids), (
+                        f"[sampling] All token IDs are 0. "
+                        f"Batch {batch_idx}, Request {i}, "
+                        f"indices={indices[i]}, text_snippet={text[:128]!r}")
+
+            print("\n[Async Streaming Test] Starting async streaming cancellation test...")
+
+            async_batch_size = 8
+            num_async_batches = 5
+            cancel_ratio = 0.5
+
+            async def run_streaming_with_cancellation():
+                for async_batch_idx in range(num_async_batches):
+                    start_idx = (async_batch_idx *
+                                 async_batch_size) % num_samples
+                    indices = [(start_idx + i) % num_samples
+                               for i in range(async_batch_size)]
+                    batch_inputs = [long_token_list[i] for i in indices]
+
+                    sampling_params_streaming = SamplingParams(max_tokens=50,
+                                                               temperature=0.8,
+                                                               top_p=0.95)
+
+                    # Generate async streaming results. generate_async takes a
+                    # single prompt, so submit one request per prompt to get
+                    # one stream per request.
+                    async_results = [
+                        llm.generate_async(
+                            prompt,
+                            sampling_params=sampling_params_streaming,
+                            streaming=True) for prompt in batch_inputs
+                    ]
+
+                    for req_idx, async_gen in enumerate(async_results):
+                        chunks_received = 0
+                        max_chunks_before_cancel = 5
+                        should_cancel = req_idx < async_batch_size * cancel_ratio
+
+                        try:
+                            async for chunk in async_gen:
+                                chunks_received += 1
+
+                                # Validate chunk
+                                if chunk.outputs:
+                                    token_ids = chunk.outputs[0].token_ids
+                                    text = chunk.outputs[0].text
+
+                                    assert len(token_ids) > 0, (
+                                        f"[async-streaming] Empty chunk. "
+                                        f"AsyncBatch {async_batch_idx}, "
+                                        f"Request {req_idx}, "
+                                        f"Chunk {chunks_received}")
+
+                                    assert not all(
+                                        tid == 0 for tid in token_ids), (
+                                        f"[async-streaming] All tokens are 0. "
+                                        f"AsyncBatch {async_batch_idx}, "
+                                        f"Request {req_idx}, "
+                                        f"Chunk {chunks_received}")
+
+                                # Simulate cancellation for some requests by
+                                # abandoning the stream early.
+                                if should_cancel and chunks_received >= max_chunks_before_cancel:
+                                    print(f"  [Cancel] Request {req_idx} "
+                                          f"after {chunks_received} chunks")
+                                    break
+                        except Exception as e:
+                            # Log but don't fail on cancellation-related errors.
+                            print(f"  [Warning] Request {req_idx} exception: {e}")
+                            if not should_cancel:
+                                raise
+
+                    print(f"  [Async Streaming] Completed batch "
+                          f"{async_batch_idx + 1}/{num_async_batches}")
+
+            # Run async streaming test
+            asyncio.run(run_streaming_with_cancellation())
+            print("[Async Streaming Test] Completed successfully")
+
+            # Accuracy eval
+            task = MMLU(model_name)
+            task.evaluate(llm)
+
+            task = GSM8K(model_name)
+            task.evaluate(llm)
 
 
 class TestMinitron4BBaseInstruct(LlmapiAccuracyTestHarness):