@@ -23,6 +23,8 @@ using namespace tensorrt_llm::testing;
 using DisaggParamsType = std::tuple<int, std::vector<std::string>, std::vector<std::vector<int>>,
     std::vector<std::vector<int>>, std::vector<int>, int>;
 
+using CondDisaggParamsType = std::tuple<std::string>;
+
 enum InstanceRole : int
 {
     CONTEXT = 1,
@@ -83,6 +85,12 @@ std::string generateTestNameDisaggParams(testing::TestParamInfo<DisaggParamsType
     return name;
 }
 
+std::string generateTestNameCondDisaggParams(testing::TestParamInfo<CondDisaggParamsType> const& info)
+{
+    auto const modelName = std::get<0>(info.param);
+    return "Model_" + modelName;
+}
+
 class DisaggParamsTest : public GptExecutorTest, public ::testing::WithParamInterface<DisaggParamsType>
 {
 };
@@ -91,6 +99,10 @@ class DisaggOrchestratorParamsTest : public GptExecutorTest, public ::testing::W
 {
 };
 
+class ConditionalDisaggParamsTest : public GptExecutorTest, public ::testing::WithParamInterface<CondDisaggParamsType>
+{
+};
+
 namespace
 {
 void verifyGenerateDistStats(std::deque<RequestStatsPerIteration> const& iterationStats)
@@ -833,6 +845,202 @@ TEST_P(DisaggOrchestratorParamsTest, DisaggTokenComparison)
 #endif
 }
 
+TEST_P(ConditionalDisaggParamsTest, DisaggTokenComparison)
+{
+#if ENABLE_MULTI_DEVICE
+    if (!tensorrt_llm::common::getEnvUseUCXKvCache())
+    {
+        setenv("UCX_TLS", "^cuda_ipc", 1); // disable cuda_ipc when testing with MPI
+    }
+    auto constexpr processNum = 2;
+    auto const& modelName = std::get<0>(GetParam());
+    auto constexpr controllerRank = 0;
+
+    // Check that the expected number of MPI processes is available.
+    auto const& world_comm = tensorrt_llm::mpi::MpiComm::world();
+    int const commRank = world_comm.getRank();
+    int const commSize = world_comm.getSize();
+    if (commSize != processNum)
+    {
+        GTEST_SKIP() << "need " << processNum << " processes but got " << commSize << " MPI processes, skipping test.";
+    }
+
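+    // Rank 0 hosts the context instance and also acts as the controller; rank 1 hosts the generation instance.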
+    bool isContext = commRank == 0;
+    bool isGeneration = commRank == 1;
+    std::vector<int> participantIds = {commRank};
+    std::vector<int> deviceIds = {commRank};
+    bool isController = (commRank == controllerRank);
+
+    OutputConfig outConfig(false, false, false, false, false, false);
+    int const beamWidth = 1;
+    BeamResult beamResult{beamWidth};
+
+    bool streaming = false;
+    int const maxBeamWidth = 1;
+    ASSERT_TRUE(fs::exists(DATA_PATH));
+
+    fs::path modelPath;
+    // Set defaults and adjust below for specific models.
+    fs::path inputPath = DATA_PATH / "input_tokens.npy";
+    ModelIds modelIds{50256, 50256};
+    bool isSpeculativeDecoding{false};
+
+    // NOTE: This can be used to disable checks for certain prompt batch entries.
+    FlakyTestInfo flakyTestInfo;
+
+    if (modelName == "gpt")
+    {
+        auto const resultsPath
+            = GPT_DATA_PATH / ((beamWidth == 1) ? "sampling" : "beam_search_" + std::to_string(beamWidth));
+        modelPath = GPT_MODEL_PATH / PathUtil::FP16_GPT_ATTENTION_PACKED_PAGED_DIR() / "tp1-pp1-cp1-gpu";
+        beamResult.resultsFile = resultsPath / PathUtil::FP16_PLUGIN_PACKED_PAGED_RESULT_FILE();
+    }
+    else if (modelName == "llama_tp1_pp1_cp1")
+    {
+        auto const resultsPath
+            = LLAMA_DATA_PATH / ((beamWidth == 1) ? "sampling" : "beam_search_" + std::to_string(beamWidth));
+        modelIds.padId = 2;
+        modelIds.endId = 2;
+        beamResult.resultsFile = resultsPath / PathUtil::FP16_PLUGIN_PACKED_PAGED_RESULT_TP1_PP1_FILE();
+        modelPath = LLAMA_MODEL_PATH / PathUtil::FP16_GPT_ATTENTION_PACKED_PAGED_DIR() / "tp1-pp1-cp1-gpu";
+    }
+    else
+    {
+        TLLM_THROW("Unrecognized modelName");
+    }
+
+    SizeType32 constexpr vocabSizePadded{50257}; // GPT vocabSizePadded
+
+    auto executorConfig = ExecutorConfig(maxBeamWidth);
+    FloatType freeGpuMemoryFraction = 0.9f;
+    KvCacheConfig kvCacheConfig{true, std::nullopt, std::nullopt, std::nullopt, freeGpuMemoryFraction};
+    executorConfig.setKvCacheConfig(kvCacheConfig);
+    executorConfig.setRequestStatsMaxIterations(1000);
+    auto manager = tr::BufferManager(std::make_shared<tr::CudaStream>());
+    auto const& givenInput = tr::utils::loadNpy(manager, inputPath.string(), tr::MemoryType::kCPU);
+    auto [givenInputLengths, nbGivenInputs, maxInputLength] = getGivenInputLengths(*givenInput, modelIds.padId);
+    world_comm.barrier();
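+    // Every rank constructs the disaggregated executor for its own role; only the controller enqueues work below.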
+    auto executor = tensorrt_llm::testing::disaggexecutor::DisaggExecutorLeader(modelPath, ModelType::kDECODER_ONLY,
+        executorConfig, isController, isContext, isGeneration, givenInputLengths.size(), participantIds, deviceIds,
+        commRank);
+
+    std::unordered_map<IdType, SizeType32> reqIdToBatchId;
+    std::unordered_map<SizeType32, std::vector<BeamTokens>> tokens;
+    auto const* const givenInputData = tr::bufferCast<TokenIdType const>(*givenInput);
+
+    auto const& inputShape = givenInput->getShape();
+    ASSERT_EQ(inputShape.nbDims, 2);
+    ASSERT_GT(inputShape.d[0], 0);
+
+    // Load expected outputs for each beam width value.
+    auto testData = TestData::loadTestData(beamResult, *givenInput, beamWidth, manager, outConfig, modelIds);
+    auto const maxSeqLen = testData.maxSeqLen;
+
+    // Build one request per given input prompt.
+    SizeType32 numRequests = static_cast<SizeType32>(givenInputLengths.size());
+    SizeType32 maxRequests = numRequests;
+    std::vector<Request> requests;
+    std::vector<SizeType32> reqMaxNewTokens;
+    SizeType32 const numReturnSequences = 1;
+
+    for (SizeType32 req = 0; req < maxRequests; ++req)
+    {
+        SizeType32 inputLen = givenInputLengths.at(req);
+        auto maxNewTokens = maxSeqLen - maxInputLength;
+        reqMaxNewTokens.push_back(maxNewTokens);
+        SizeType32 endId = -1;
+        auto const* const seqBegin = givenInputData + req * maxInputLength;
+        VecTokens tokens(seqBegin, seqBegin + inputLen);
+        auto samplingConfig = tensorrt_llm::executor::SamplingConfig(beamWidth);
+        samplingConfig.setNumReturnSequences(numReturnSequences);
+        auto request = Request(
+            VecTokens(seqBegin, seqBegin + inputLen), maxNewTokens, streaming, samplingConfig, outConfig, endId);
+        request.setReturnAllGeneratedTokens(false);
+        // Alternate the request type: even-indexed requests are context-only, odd-indexed requests run context and generation.
+        if (req % 2 == 0)
+        {
+            request.setRequestType(RequestType::REQUEST_TYPE_CONTEXT_ONLY);
+        }
+        else
+        {
+            request.setRequestType(RequestType::REQUEST_TYPE_CONTEXT_AND_GENERATION);
+        }
+        requests.emplace_back(std::move(request));
+    }
+
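+    // Only the controller rank submits requests, gathers responses, and verifies the outputs.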
+    if (isController)
+    {
+        std::vector<IdType> reqIds;
+
+        for (int i = 0; i < requests.size(); ++i)
+        {
+            std::vector<BeamTokens> resultTokens;
+            resultTokens.reserve(numReturnSequences);
+            for (SizeType32 seqIdx = 0; seqIdx < numReturnSequences; ++seqIdx)
+            {
+                resultTokens.emplace_back(beamWidth);
+            }
+            auto retReqId = executor.enqueueRequests({requests[i]});
+            reqIds.push_back(retReqId.front());
+            tokens[i] = std::move(resultTokens);
+            reqIdToBatchId[retReqId.front()] = i;
+        }
+
+        // Collect the new tokens for each request.
+        int32_t numFinished = 0;
+        int iter = 0;
+        SizeType32 numResponses = 0;
+        while (numFinished < maxRequests && iter < mMaxWaitMs)
+        {
+            std::chrono::milliseconds waitTime(1);
+            auto responses = executor.awaitResponses(waitTime);
+            for (auto& response : responses)
+            {
+                numResponses++;
+                if (!response.hasError())
+                {
+                    auto result = response.getResult();
+                    numFinished += result.isFinal;
+                    auto batchId = reqIdToBatchId.at(response.getRequestId());
+                    auto seqIdx = result.sequenceIndex;
+
+                    auto& outputTokenIds = result.outputTokenIds;
+
+                    EXPECT_EQ(result.finishReasons.size(), beamWidth);
+                    for (SizeType32 beam = 0; beam < beamWidth; ++beam)
+                    {
+                        auto& newTokens = outputTokenIds.at(beam);
+                        auto& reqTokens = tokens.at(batchId).at(seqIdx).at(beam);
+
+                        reqTokens.insert(reqTokens.end(), newTokens.begin(), newTokens.end());
+                        // FinishReason is only supported for beamWidth == 1 and inflight batching.
+                        if (beamWidth == 1 && executorConfig.getBatchingType() == BatchingType::kINFLIGHT)
+                        {
+                            EXPECT_EQ(result.finishReasons.at(beam),
+                                result.isFinal ? FinishReason::kLENGTH : FinishReason::kNOT_FINISHED);
+                        }
+                    }
+                }
+                else
+                {
+                    // A response with an error is allowed only if awaitResponses processed a terminated request id.
+                    std::string err = "ReqId " + std::to_string(response.getRequestId())
+                        + " has already been processed and was terminated.";
+                    EXPECT_EQ(response.getErrorMsg(), err);
+                }
+            }
+            ++iter;
+        }
+        EXPECT_LT(iter, mMaxWaitMs);
+        testData.verifyOutput(tokens, givenInputLengths, nbGivenInputs, streaming, outConfig.excludeInputFromOutput,
+            flakyTestInfo, isSpeculativeDecoding, false, beamWidth, numReturnSequences, false);
+    }
+    world_comm.barrier();
+#else
+    GTEST_SKIP() << "Skipping DisaggExecutor Test";
+#endif
+}
+
 INSTANTIATE_TEST_SUITE_P(GptDisaggSymmetricExecutorTest, DisaggParamsTest,
     testing::Combine(testing::Values(2), testing::Values(std::vector<std::string>{"gpt", "gpt"}),
         testing::Values(std::vector<std::vector<int>>{{0}, {1}}),
@@ -868,6 +1076,9 @@ INSTANTIATE_TEST_SUITE_P(GptSingleDeviceDisaggSymmetricExecutorMixedTest, Disagg
         testing::Values(1)),
     generateTestNameDisaggParams);
 
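+// Run the conditional disaggregation test for both the GPT and the single-GPU LLaMA engine.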
+INSTANTIATE_TEST_SUITE_P(ConditionalDisaggExecutorTest, ConditionalDisaggParamsTest,
+    testing::Combine(testing::Values("gpt", "llama_tp1_pp1_cp1")), generateTestNameCondDisaggParams);
+
 INSTANTIATE_TEST_SUITE_P(LlamaTP2DisaggSymmetricExecutorTest, DisaggParamsTest,
     testing::Combine(testing::Values(4),
         testing::Values(std::vector<std::string>{"llama_tp2_pp1_cp1", "llama_tp2_pp1_cp1"}),