
Commit d70ff79

conditional disagg test (NVIDIA#3012)
Signed-off-by: Zheng Duan <[email protected]>
1 parent 3e116c9 commit d70ff79

File tree

2 files changed: +251 -14 lines changed


cpp/tests/executor/disaggExecutor.h

Lines changed: 40 additions & 14 deletions
@@ -56,10 +56,11 @@ enum class MessageID : uint64_t
 {
     PENDING_CONTEXT_REQUEST = 1,
     PENDING_GENERATION_REQUEST = 2,
-    CONTEXT_RESPONSE = 3,
-    GENERATION_RESPONSE = 4,
+    PENDING_FULL_REQUEST = 3,
+    CONTEXT_RESPONSE = 4,
+    GENERATION_RESPONSE = 5,

-    TERMINATION = 5,
+    TERMINATION = 6,
 };

 enum DisaggRole : uint32_t
@@ -261,23 +262,37 @@ class DisaggExecutorLeader
         }

         std::vector<RequestWithId> requestWithIds;
+        std::vector<RequestWithId> requestWithIdsFull; // full request, not disaggregated
         std::vector<IdType> reqIds;
         for (auto const& req : llmRequests)
         {
             IdType id = generatedControlId();
             reqIds.push_back(id);

             RequestWithId reqWithId{req, id};
-            reqWithId.req.setRequestType(RequestType::REQUEST_TYPE_CONTEXT_ONLY);
-
-            requestWithIds.push_back(std::move(reqWithId));
+            if (req.getRequestType() == RequestType::REQUEST_TYPE_CONTEXT_ONLY)
+            {
+                requestWithIds.push_back(std::move(reqWithId));
+            }
+            else
+            {
+                TLLM_CHECK(req.getRequestType() == RequestType::REQUEST_TYPE_CONTEXT_AND_GENERATION);
+                requestWithIdsFull.push_back(std::move(reqWithId));
+            }

             mRequestMap.insert(std::make_pair(id, req));
         }

-        Message message{MessageID::PENDING_CONTEXT_REQUEST, MessageData{RequestsData{requestWithIds}}};
-
-        mControllerSendQueue.push(std::move(message));
+        if (!requestWithIds.empty())
+        {
+            Message message{MessageID::PENDING_CONTEXT_REQUEST, MessageData{RequestsData{requestWithIds}}};
+            mControllerSendQueue.push(std::move(message));
+        }
+        if (!requestWithIdsFull.empty())
+        {
+            Message message{MessageID::PENDING_FULL_REQUEST, MessageData{RequestsData{requestWithIdsFull}}};
+            mControllerSendQueue.push(std::move(message));
+        }

         return reqIds;
     }
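
The split above is driven entirely by the request type the caller sets before enqueueing. As a rough caller-side sketch against the tensorrt_llm::executor API (illustrative only, not part of this commit; the helper name buildMixedRequests and the even/odd policy are hypothetical and simply mirror the new test below):

    // Sketch: build a batch that mixes the disaggregated and the aggregated ("full") path.
    // Request, RequestType and setRequestType are the executor API types used in this diff.
    #include <cstddef>
    #include <vector>
    #include "tensorrt_llm/executor/executor.h"

    namespace texec = tensorrt_llm::executor;

    std::vector<texec::Request> buildMixedRequests(
        std::vector<texec::VecTokens> const& prompts, texec::SizeType32 maxNewTokens)
    {
        std::vector<texec::Request> requests;
        requests.reserve(prompts.size());
        for (std::size_t i = 0; i < prompts.size(); ++i)
        {
            texec::Request req(prompts[i], maxNewTokens);
            // Even entries take the disaggregated path (batched into a PENDING_CONTEXT_REQUEST);
            // odd entries stay aggregated and travel as PENDING_FULL_REQUEST.
            req.setRequestType(i % 2 == 0 ? texec::RequestType::REQUEST_TYPE_CONTEXT_ONLY
                                          : texec::RequestType::REQUEST_TYPE_CONTEXT_AND_GENERATION);
            requests.emplace_back(std::move(req));
        }
        return requests;
    }

Per the hunks below, requests tagged REQUEST_TYPE_CONTEXT_AND_GENERATION are routed as PENDING_FULL_REQUEST and handled by the generation ranks, while context-only requests keep the existing PENDING_CONTEXT_REQUEST path.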
@@ -547,7 +562,8 @@ class DisaggExecutorLeader
                 mWorldComm.send(packed.data(), packed.size(), tensorrt_llm::mpi::MpiType::kCHAR, contextRank,
                     kM_CONTROLLER_DATA_TAG);
             }
-            else if (message.id == MessageID::PENDING_GENERATION_REQUEST)
+            else if (message.id == MessageID::PENDING_GENERATION_REQUEST
+                || message.id == MessageID::PENDING_FULL_REQUEST)
             {

                 auto& reqWithIds = std::get<RequestsData>(message.data);
@@ -713,7 +729,8 @@ class DisaggExecutorLeader
                 shutDown();
                 break;
             }
-            if (messageId == MessageID::PENDING_CONTEXT_REQUEST || messageId == MessageID::PENDING_GENERATION_REQUEST)
+            if (messageId == MessageID::PENDING_CONTEXT_REQUEST || messageId == MessageID::PENDING_GENERATION_REQUEST
+                || messageId == MessageID::PENDING_FULL_REQUEST)
             {
                 mWorldComm.mprobe(sourceRank, kM_CONTROLLER_DATA_TAG, &msg, &status);
                 MPICHECK(MPI_Get_count(&status, MPI_CHAR, &count));
@@ -728,13 +745,22 @@ class DisaggExecutorLeader
                    {
                        TLLM_CHECK(requestWithId.req.getRequestType() == RequestType::REQUEST_TYPE_CONTEXT_ONLY);
                    }
-                    else if (isGenerationRank() && messageId == MessageID::PENDING_GENERATION_REQUEST)
+                    else if (isGenerationRank()
+                        && (messageId == MessageID::PENDING_GENERATION_REQUEST
+                            || messageId == MessageID::PENDING_FULL_REQUEST))
                    {
-                        TLLM_CHECK(requestWithId.req.getRequestType() == RequestType::REQUEST_TYPE_GENERATION_ONLY);
+                        if (messageId == MessageID::PENDING_GENERATION_REQUEST)
+                        {
+                            TLLM_CHECK(requestWithId.req.getRequestType() == RequestType::REQUEST_TYPE_GENERATION_ONLY);
+                        }
+                        else // PENDING_FULL_REQUEST
+                        {
+                            TLLM_CHECK(
+                                requestWithId.req.getRequestType() == RequestType::REQUEST_TYPE_CONTEXT_AND_GENERATION);
+                        }
                    }
                    else
                    {
-                        // TODO: support full request (aggregagted)
                        TLLM_THROW("rank:%d, size:%d InstanceLeaderRecvThread recv Invalid message id:%ld",
                            mWorldComm.getRank(), mWorldComm.getSize(), static_cast<uint64_t>(messageId));
                    }

cpp/tests/executor/disaggExecutorTest.cpp

Lines changed: 211 additions & 0 deletions
@@ -23,6 +23,8 @@ using namespace tensorrt_llm::testing;
 using DisaggParamsType = std::tuple<int, std::vector<std::string>, std::vector<std::vector<int>>,
     std::vector<std::vector<int>>, std::vector<int>, int>;

+using CondDisaggParamsType = std::tuple<std::string>;
+
 enum InstanceRole : int
 {
     CONTEXT = 1,
@@ -83,6 +85,12 @@ std::string generateTestNameDisaggParams(testing::TestParamInfo<DisaggParamsType
     return name;
 }

+std::string generateTestNameCondDisaggParams(testing::TestParamInfo<CondDisaggParamsType> const& info)
+{
+    auto const modelName = std::get<0>(info.param);
+    return "Model_" + modelName;
+}
+
 class DisaggParamsTest : public GptExecutorTest, public ::testing::WithParamInterface<DisaggParamsType>
 {
 };
@@ -91,6 +99,10 @@ class DisaggOrchestratorParamsTest : public GptExecutorTest, public ::testing::W
 {
 };

+class ConditionalDisaggParamsTest : public GptExecutorTest, public ::testing::WithParamInterface<CondDisaggParamsType>
+{
+};
+
 namespace
 {
 void verifyGenerateDistStats(std::deque<RequestStatsPerIteration> const& iterationStats)
@@ -833,6 +845,202 @@ TEST_P(DisaggOrchestratorParamsTest, DisaggTokenComparison)
 #endif
 }

+TEST_P(ConditionalDisaggParamsTest, DisaggTokenComparison)
+{
+#if ENABLE_MULTI_DEVICE
+    if (!tensorrt_llm::common::getEnvUseUCXKvCache())
+    {
+        setenv("UCX_TLS", "^cuda_ipc", 1); // disable cuda_ipc for testing for mpi
+    }
+    auto constexpr processNum = 2;
+    auto const& modelName = std::get<0>(GetParam());
+    auto constexpr controllerRank = 0;
+
+    // params_check
+    auto const& world_comm = tensorrt_llm::mpi::MpiComm::world();
+    int const commRank = world_comm.getRank();
+    int const commSize = world_comm.getSize();
+    if (commSize != processNum)
+    {
+        GTEST_SKIP() << " need " << processNum << " processes but got " << commSize << " mpi processes, skip test.";
+    }
+
+    bool isContext = commRank == 0;
+    bool isGeneration = commRank == 1;
+    std::vector<int> participatntIds = {commRank};
+    std::vector<int> deviceIds = {commRank};
+    bool isController = (commRank == controllerRank);
+
+    OutputConfig outConfig(false, false, false, false, false, false);
+    int const beamWidth = 1;
+    BeamResult beamResult{beamWidth};
+
+    bool streaming = false;
+    int const maxBeamWidth = 1;
+    ASSERT_TRUE(fs::exists(DATA_PATH));
+
+    fs::path modelPath;
+    // set defaults and adjust if needed by different models
+    fs::path inputPath = DATA_PATH / "input_tokens.npy";
+    ModelIds modelIds{50256, 50256};
+    bool isSpeculativeDecoding{false};
+
+    // NOTE: This can be used to disable checks for certain prompt batch entries
+    FlakyTestInfo flakyTestInfo;
+
+    if (modelName == "gpt")
+    {
+        auto const resultsPath
+            = GPT_DATA_PATH / ((beamWidth == 1) ? "sampling" : "beam_search_" + std::to_string(beamWidth));
+        modelPath = GPT_MODEL_PATH / PathUtil::FP16_GPT_ATTENTION_PACKED_PAGED_DIR() / "tp1-pp1-cp1-gpu";
+        beamResult.resultsFile = resultsPath / PathUtil::FP16_PLUGIN_PACKED_PAGED_RESULT_FILE();
+    }
+    else if (modelName == "llama_tp1_pp1_cp1")
+    {
+        auto const resultsPath
+            = LLAMA_DATA_PATH / ((beamWidth == 1) ? "sampling" : "beam_search_" + std::to_string(beamWidth));
+        modelIds.padId = 2;
+        modelIds.endId = 2;
+        beamResult.resultsFile = resultsPath / PathUtil::FP16_PLUGIN_PACKED_PAGED_RESULT_TP1_PP1_FILE();
+        modelPath = LLAMA_MODEL_PATH / PathUtil::FP16_GPT_ATTENTION_PACKED_PAGED_DIR() / "tp1-pp1-cp1-gpu";
+    }
+    else
+    {
+        TLLM_THROW("Unrecognized modelName");
+    }
+
+    SizeType32 constexpr vocabSizePadded{50257}; // gpt vocabSizePadded
+
+    auto executorConfig = ExecutorConfig(maxBeamWidth);
+    FloatType freeGpuMemoryFraction = 0.9f;
+    KvCacheConfig kvCacheConfig{true, std::nullopt, std::nullopt, std::nullopt, freeGpuMemoryFraction};
+    executorConfig.setKvCacheConfig(kvCacheConfig);
+    executorConfig.setRequestStatsMaxIterations(1000);
+    auto manager = tr::BufferManager(std::make_shared<tr::CudaStream>());
+    auto const& givenInput = tr::utils::loadNpy(manager, inputPath.string(), tr::MemoryType::kCPU);
+    auto [givenInputLengths, nbGivenInputs, maxInputLength] = getGivenInputLengths(*givenInput, modelIds.padId);
+    world_comm.barrier();
+    auto executor = tensorrt_llm::testing::disaggexecutor::DisaggExecutorLeader(modelPath, ModelType::kDECODER_ONLY,
+        executorConfig, isController, isContext, isGeneration, givenInputLengths.size(), participatntIds, deviceIds,
+        commRank);
+
+    std::unordered_map<IdType, SizeType32> reqIdToBatchId;
+    std::unordered_map<SizeType32, std::vector<BeamTokens>> tokens;
+    auto const* const givenInputData = tr::bufferCast<TokenIdType const>(*givenInput);
+
+    auto const& inputShape = givenInput->getShape();
+    ASSERT_EQ(inputShape.nbDims, 2);
+    ASSERT_GT(inputShape.d[0], 0);
+
+    // Load expected outputs for each beam width value
+    auto testData = TestData::loadTestData(beamResult, *givenInput, beamWidth, manager, outConfig, modelIds);
+    auto const maxSeqLen = testData.maxSeqLen;
+
+    // Load expected outputs and inputs
+    SizeType32 numRequests = static_cast<SizeType32>(givenInputLengths.size());
+    SizeType32 maxRequests = numRequests;
+    std::vector<Request> requests;
+    std::vector<SizeType32> reqMaxNewTokens;
+    SizeType32 const numReturnSequences = 1;
+
+    for (SizeType32 req = 0; req < maxRequests; ++req)
+    {
+        SizeType32 inputLen = givenInputLengths.at(req);
+        auto maxNewTokens = maxSeqLen - maxInputLength;
+        reqMaxNewTokens.push_back(maxNewTokens);
+        SizeType32 endId = -1;
+        auto const* const seqBegin = givenInputData + req * maxInputLength;
+        VecTokens tokens(seqBegin, seqBegin + inputLen);
+        auto samplingConfig = tensorrt_llm::executor::SamplingConfig(beamWidth);
+        samplingConfig.setNumReturnSequences(numReturnSequences);
+        auto request = Request(
+            VecTokens(seqBegin, seqBegin + inputLen), maxNewTokens, streaming, samplingConfig, outConfig, endId);
+        request.setReturnAllGeneratedTokens(false);
+        // setting request type to context/full by condition
+        if (req % 2 == 0)
+        {
+            request.setRequestType(RequestType::REQUEST_TYPE_CONTEXT_ONLY);
+        }
+        else
+        {
+            request.setRequestType(RequestType::REQUEST_TYPE_CONTEXT_AND_GENERATION);
+        }
+        requests.emplace_back(std::move(request));
+    }
+
+    if (isController)
+    {
+        std::vector<IdType> reqIds;
+
+        for (int i = 0; i < requests.size(); ++i)
+        {
+            std::vector<BeamTokens> resultTokens;
+            resultTokens.reserve(numReturnSequences);
+            for (SizeType32 seqIdx = 0; seqIdx < numReturnSequences; ++seqIdx)
+            {
+                resultTokens.emplace_back(beamWidth);
+            }
+            auto retReqId = executor.enqueueRequests({requests[i]});
+            reqIds.push_back(retReqId.front());
+            tokens[i] = std::move(resultTokens);
+            reqIdToBatchId[retReqId.front()] = i;
+        }
+
+        // Get the new tokens for each requests
+        int32_t numFinished = 0;
+        int iter = 0;
+        SizeType32 numResponses = 0;
+        while (numFinished < maxRequests && iter < mMaxWaitMs)
+        {
+            std::chrono::milliseconds waitTime(1);
+            auto responses = executor.awaitResponses(waitTime);
+            for (auto& response : responses)
+            {
+                numResponses++;
+                if (!response.hasError())
+                {
+                    auto result = response.getResult();
+                    numFinished += result.isFinal;
+                    auto batchId = reqIdToBatchId.at(response.getRequestId());
+                    auto seqIdx = result.sequenceIndex;
+
+                    auto& outputTokenIds = result.outputTokenIds;
+
+                    EXPECT_EQ(result.finishReasons.size(), beamWidth);
+                    for (SizeType32 beam = 0; beam < beamWidth; ++beam)
+                    {
+                        auto& newTokens = outputTokenIds.at(beam);
+                        auto& reqTokens = tokens.at(batchId).at(seqIdx).at(beam);
+
+                        reqTokens.insert(reqTokens.end(), newTokens.begin(), newTokens.end());
+                        // FinishReason is only supported for bw=1 and inflight batching.
+                        if (beamWidth == 1 && executorConfig.getBatchingType() == BatchingType::kINFLIGHT)
+                        {
+                            EXPECT_EQ(result.finishReasons.at(beam),
+                                result.isFinal ? FinishReason::kLENGTH : FinishReason::kNOT_FINISHED);
+                        }
+                    }
+                }
+                else
+                {
+                    // Allow response with error only if awaitResponse processed a terminated request id
+                    std::string err = "ReqId " + std::to_string(response.getRequestId())
+                        + " has already been processed and was terminated.";
+                    EXPECT_EQ(response.getErrorMsg(), err);
+                }
+            }
+            ++iter;
+        }
+        EXPECT_LT(iter, mMaxWaitMs);
+        testData.verifyOutput(tokens, givenInputLengths, nbGivenInputs, streaming, outConfig.excludeInputFromOutput,
+            flakyTestInfo, isSpeculativeDecoding, false, beamWidth, numReturnSequences, false);
+    }
+    world_comm.barrier();
+#else
+    GTEST_SKIP() << "Skipping DisaggExecutor Test";
+#endif
+}
+
 INSTANTIATE_TEST_SUITE_P(GptDisaggSymmetricExecutorTest, DisaggParamsTest,
     testing::Combine(testing::Values(2), testing::Values(std::vector<std::string>{"gpt", "gpt"}),
         testing::Values(std::vector<std::vector<int>>{{0}, {1}}),
@@ -868,6 +1076,9 @@ INSTANTIATE_TEST_SUITE_P(GptSingleDeviceDisaggSymmetricExecutorMixedTest, Disagg
         testing::Values(1)),
     generateTestNameDisaggParams);

+INSTANTIATE_TEST_SUITE_P(ConditionalDisaggExecutorTest, ConditionalDisaggParamsTest,
+    testing::Combine(testing::Values("gpt", "llama_tp1_pp1_cp1")), generateTestNameCondDisaggParams);
+
 INSTANTIATE_TEST_SUITE_P(LlamaTP2DisaggSymmetricExecutorTest, DisaggParamsTest,
     testing::Combine(testing::Values(4),
         testing::Values(std::vector<std::string>{"llama_tp2_pp1_cp1", "llama_tp2_pp1_cp1"}),