
Commit 6703df9

Merge branch 'main' into user/williamj/use-flashinfer-mamba-kernel

Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>

2 parents: 4ff479f + 9beb971

104 files changed: +5377 −1313 lines

ATTRIBUTIONS-Python.md
Lines changed: 1 addition & 1 deletion

@@ -5261,7 +5261,7 @@ For more information, please refer to <http://unlicense.org>
 - `Tracker`: https://github.com/tox-dev/py-filelock/issues
 
 
-## flashinfer-python (0.6.0)
+## flashinfer-python (0.6.1)
 
 ### Licenses
 License: `Apache-2.0`

README.md
Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@ This branch is a prototype and not stable for production use. PRs are not accept
 [![python](https://img.shields.io/badge/python-3.10-green)](https://www.python.org/downloads/release/python-31012/)
 [![cuda](https://img.shields.io/badge/cuda-13.1.0-green)](https://developer.nvidia.com/cuda-downloads)
 [![torch](https://img.shields.io/badge/torch-2.9.1-green)](https://pytorch.org)
-[![version](https://img.shields.io/badge/release-1.3.0rc0-green)](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/version.py)
+[![version](https://img.shields.io/badge/release-1.3.0rc1-green)](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/version.py)
 [![license](https://img.shields.io/badge/license-Apache%202-blue)](https://github.com/NVIDIA/TensorRT-LLM/blob/main/LICENSE)
 
 [Architecture](https://nvidia.github.io/TensorRT-LLM/developer-guide/overview.html)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Performance](https://nvidia.github.io/TensorRT-LLM/developer-guide/perf-overview.html)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Examples](https://nvidia.github.io/TensorRT-LLM/quick-start-guide.html)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Documentation](https://nvidia.github.io/TensorRT-LLM/)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Roadmap](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue%20state%3Aopen%20label%3Aroadmap)

cpp/include/tensorrt_llm/batch_manager/peftCacheManager.h
Lines changed: 12 additions & 2 deletions

@@ -57,7 +57,10 @@ class BasePeftCacheManager
 public:
     using LlmRequestPtr = std::shared_ptr<LlmRequest>;
     using RequestVector = std::vector<LlmRequestPtr>;
-    using PeftTable = std::map<uint64_t, std::vector<runtime::LoraCache::TaskLayerModuleConfig>>;
+    using PeftTable = std::unordered_map<uint64_t, std::vector<runtime::LoraCache::TaskLayerModuleConfig>>;
+    using TaskPeftTable = std::unordered_map<uint64_t, std::vector<runtime::LoraCache::TaskLayerModuleConfig>>;
+    using TaskIdToReqIds = std::unordered_map<uint64_t, std::vector<uint64_t>>;
+    using EnsureBatchTaskResult = std::tuple<TaskPeftTable, TaskIdToReqIds>;
 
     virtual ~BasePeftCacheManager() = default;
 
@@ -99,6 +102,8 @@ class BasePeftCacheManager
 class PeftCacheManager : public BasePeftCacheManager
 {
 public:
+    using EnsureBatchTaskResult = BasePeftCacheManager::EnsureBatchTaskResult;
+
     PeftCacheManager(PeftCacheManagerConfig const& config, runtime::ModelConfig const& modelConfig,
         runtime::WorldConfig const& worldConfig, runtime::BufferManager const& bufferManager);
 
@@ -109,12 +114,17 @@ class PeftCacheManager : public BasePeftCacheManager
     PeftTable ensureBatch(RequestVector const& contextRequests, RequestVector const& generationRequests,
         bool resetGpuCache = false) override;
 
+    EnsureBatchTaskResult ensureBatchMapTaskId(
+        RequestVector const& contextRequests, RequestVector const& generationRequests, bool resetGpuCache = false);
+
     [[nodiscard]] bool isTaskCached(uint64_t taskId) const;
 
     [[nodiscard]] bool isTaskDone(uint64_t taskId) const;
 
     [[nodiscard]] bool isTaskDoneDevice(uint64_t taskId) const;
 
+    [[nodiscard]] bool isTaskCachedDevice(uint64_t const taskId) const;
+
     void resetDeviceCache() override;
 
     void markRequestDone(LlmRequest const& llmReq, bool pause = false) override;
@@ -159,7 +169,7 @@ class PeftCacheManager : public BasePeftCacheManager
     std::unordered_map<uint64_t, std::unordered_set<uint64_t>> mTaskIdToReqIds;
     std::unordered_map<uint64_t, std::unordered_set<uint64_t>> mTaskIdToPausedReqIds;
 
-    std::tuple<std::map<uint64_t, std::future<void>>, std::map<uint64_t, std::vector<uint64_t>>> getTaskMaps(
+    std::tuple<std::unordered_map<uint64_t, std::future<void>>, TaskIdToReqIds> getTaskMaps(
         RequestVector const& contextRequests, RequestVector const& generationRequests);
 
     runtime::ModelConfig mModelConfig;
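The header's containers move from std::map to std::unordered_map: task IDs need fast lookup, not ordered iteration, so the hashed container's average O(1) access is the better fit. Below is a minimal sketch of how the new EnsureBatchTaskResult tuple is consumed via structured bindings; ToyConfig is a hypothetical stand-in for runtime::LoraCache::TaskLayerModuleConfig, whose definition is not part of this diff.

    #include <cstdint>
    #include <tuple>
    #include <unordered_map>
    #include <vector>

    // Hypothetical stand-in for runtime::LoraCache::TaskLayerModuleConfig.
    struct ToyConfig
    {
        int layerId;
    };

    using TaskPeftTable = std::unordered_map<uint64_t, std::vector<ToyConfig>>;
    using TaskIdToReqIds = std::unordered_map<uint64_t, std::vector<uint64_t>>;
    using EnsureBatchTaskResult = std::tuple<TaskPeftTable, TaskIdToReqIds>;

    int main()
    {
        // One LoRA task (id 7) with two layer configs, shared by requests 100 and 101.
        EnsureBatchTaskResult result{{{7, {{0}, {1}}}}, {{7, {100, 101}}}};
        auto const& [taskTable, taskToReqs] = result;
        return taskTable.at(7).size() == 2 && taskToReqs.at(7).size() == 2 ? 0 : 1;
    }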

cpp/include/tensorrt_llm/executor/executor.h
Lines changed: 5 additions & 1 deletion

@@ -684,6 +684,7 @@ class Request
     /// finish reason. The request may exceed this time slightly, but at most by 1 forward pass (in pipeline parallelism
     /// that may involve multiple micro-batches). A request can be timed-out before ever being scheduled.
     /// @param cacheSaltID Salt ID for KV cache blocks to limit the kv cache reuse to the requests with the same string.
+    /// @param disaggRequestId Disaggregated request ID.
     Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming = false,
         SamplingConfig const& samplingConfig = SamplingConfig(), OutputConfig const& outputConfig = OutputConfig(),
         std::optional<SizeType32> const& endId = std::nullopt, std::optional<SizeType32> const& padId = std::nullopt,
@@ -711,7 +712,8 @@ class Request
         std::optional<GuidedDecodingParams> guidedDecodingParams = std::nullopt,
         std::optional<SizeType32> languageAdapterUid = std::nullopt,
         std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
-        std::optional<CacheSaltIDType> cacheSaltID = std::nullopt);
+        std::optional<CacheSaltIDType> cacheSaltID = std::nullopt,
+        std::optional<IdType> disaggRequestId = std::nullopt);
 
     /// @brief This logits postprocessor name will dispatch to the batched logits postprocessor
     static auto constexpr kBatchedPostProcessorName = "batched";
@@ -761,6 +763,7 @@ class Request
     [[nodiscard]] std::optional<MillisecondsType> getAllottedTimeMs() const;
     [[nodiscard]] std::optional<CacheSaltIDType> getCacheSaltID() const;
     [[nodiscard]] std::optional<std::vector<std::string>> getAdditionalOutputNames() const;
+    [[nodiscard]] std::optional<IdType> getDisaggRequestId() const;
 
     void setStreaming(bool streaming);
     void setSamplingConfig(SamplingConfig const& config);
@@ -796,6 +799,7 @@ class Request
     void setLanguageAdapterUid(SizeType32 languageAdapterUid);
     void setAllottedTimeMs(MillisecondsType allottedTimeMs);
     void setCacheSaltID(CacheSaltIDType cacheSaltID);
+    void setDisaggRequestId(IdType disaggRequestId);
 
 private:
     friend class Serialization;
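For callers, the new field is reachable both through the (long) constructor's trailing parameter and through the setter. A minimal usage sketch, assuming only the declarations visible in this diff; the token ids and the id value are illustrative:

    #include "tensorrt_llm/executor/executor.h"

    namespace tle = tensorrt_llm::executor;

    void tagAsDisaggregated(tle::IdType globalId)
    {
        // All constructor parameters after maxTokens are defaulted, so the
        // setter is the simplest way to attach the disaggregated id.
        tle::Request req({1, 2, 3}, /*maxTokens=*/16);
        req.setDisaggRequestId(globalId);
        // getDisaggRequestId() now yields a populated std::optional<IdType>.
    }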

cpp/tensorrt_llm/batch_manager/peftCacheManager.cpp
Lines changed: 30 additions & 10 deletions

@@ -373,11 +373,11 @@ void PeftCacheManager::addRequestPeft(std::shared_ptr<LlmRequest> llmRequest, bo
     TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
 }
 
-std::tuple<std::map<uint64_t, std::future<void>>, std::map<uint64_t, std::vector<uint64_t>>>
+std::tuple<std::unordered_map<uint64_t, std::future<void>>, BasePeftCacheManager::TaskIdToReqIds>
 PeftCacheManager::getTaskMaps(RequestVector const& contextRequests, RequestVector const& generationRequests)
 {
-    std::map<uint64_t, std::vector<uint64_t>> taskIdToReqIds;
-    std::map<uint64_t, std::future<void>> taskIdToFuture;
+    TaskIdToReqIds taskIdToReqIds;
+    std::unordered_map<uint64_t, std::future<void>> taskIdToFuture;
     std::lock_guard<std::mutex> futuresLock(mPutFuturesMutex);
     for (auto const& requests : {contextRequests, generationRequests})
     {
@@ -415,7 +415,7 @@ PeftCacheManager::getTaskMaps(RequestVector const& contextRequests, RequestVecto
     return {std::move(taskIdToFuture), taskIdToReqIds};
 }
 
-PeftCacheManager::PeftTable PeftCacheManager::ensureBatch(
+PeftCacheManager::EnsureBatchTaskResult PeftCacheManager::ensureBatchMapTaskId(
     RequestVector const& contextRequests, RequestVector const& generationRequests, bool resetGpuCache)
 {
     TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
@@ -426,7 +426,7 @@ PeftCacheManager::PeftTable PeftCacheManager::ensureBatch(
     auto [taskIdToFuture_, taskIdToReqIds] = getTaskMaps(contextRequests, generationRequests);
     auto taskIdToFuture = std::move(taskIdToFuture_); // captured structured bindings are a C++20 extension
 
-    std::map<uint64_t, std::future<std::vector<runtime::LoraCache::TaskLayerModuleConfig>>> ensureFutures;
+    std::unordered_map<uint64_t, std::future<std::vector<runtime::LoraCache::TaskLayerModuleConfig>>> ensureFutures;
     for (auto& [taskId, taskFuture] : taskIdToFuture)
     {
         auto fn = [&taskIdToFuture, taskId = taskId, this]() -> std::vector<runtime::LoraCache::TaskLayerModuleConfig>
@@ -457,18 +457,31 @@ PeftCacheManager::PeftTable PeftCacheManager::ensureBatch(
         ensureFutures.try_emplace(taskId, std::move(f));
     }
 
-    PeftTable peftTable{};
+    TaskPeftTable peftTable{};
     for (auto const& [taskId, reqIds] : taskIdToReqIds)
     {
         auto&& f = ensureFutures.at(taskId);
         auto const values = f.get();
-        for (auto const& reqId : reqIds)
+        peftTable.try_emplace(taskId, values);
+    }
+    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
+    return {std::move(peftTable), std::move(taskIdToReqIds)};
+}
+
+PeftCacheManager::PeftTable PeftCacheManager::ensureBatch(
+    RequestVector const& contextRequests, RequestVector const& generationRequests, bool resetGpuCache)
+{
+    auto [taskTable, taskIdToReqIds] = ensureBatchMapTaskId(contextRequests, generationRequests, resetGpuCache);
+    PeftTable requestTable{};
+    for (auto const& [taskId, values] : taskTable)
+    {
+        auto const& reqIds = taskIdToReqIds.at(taskId);
+        for (auto const reqId : reqIds)
         {
-            peftTable.try_emplace(reqId, values);
+            requestTable.try_emplace(reqId, values);
         }
     }
-    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
-    return peftTable;
+    return requestTable;
 }
 
 bool PeftCacheManager::isTaskCached(uint64_t taskId) const
@@ -486,6 +499,11 @@ bool PeftCacheManager::isTaskDoneDevice(uint64_t taskId) const
     return mDeviceLoraCache->isDone(taskId);
 }
 
+bool PeftCacheManager::isTaskCachedDevice(uint64_t const taskId) const
+{
+    return mDeviceLoraCache->has(taskId);
+}
+
 void PeftCacheManager::updateTaskState(uint64_t taskId, uint64_t reqId, bool terminate, bool pause)
 {
     if (!terminate)
@@ -645,3 +663,5 @@ SizeType32 NoOpPeftCacheManager::determineNumPages(std::shared_ptr<LlmRequest> l
     return 0;
 }
 } // namespace tensorrt_llm::batch_manager
+
+// TODO: merge C++ LoRA caching status with Py Slot manager
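The refactor splits the old ensureBatch() into two stages: ensureBatchMapTaskId() resolves each LoRA task once and returns a task-keyed table plus the task-to-request mapping, and ensureBatch() fans that out into the request-keyed PeftTable. A self-contained sketch of the fan-out stage with toy types (the real value type is std::vector<runtime::LoraCache::TaskLayerModuleConfig>):

    #include <cstdint>
    #include <unordered_map>
    #include <vector>

    using Values = std::vector<int>; // toy stand-in for the per-task layer configs

    // Mirror of the new ensureBatch() fan-out: every request mapped to a task
    // receives a copy of that task's values, so each task is resolved only once
    // even when many requests share it.
    std::unordered_map<uint64_t, Values> expandToRequests(
        std::unordered_map<uint64_t, Values> const& taskTable,
        std::unordered_map<uint64_t, std::vector<uint64_t>> const& taskIdToReqIds)
    {
        std::unordered_map<uint64_t, Values> requestTable;
        for (auto const& [taskId, values] : taskTable)
        {
            for (auto const reqId : taskIdToReqIds.at(taskId))
            {
                requestTable.try_emplace(reqId, values);
            }
        }
        return requestTable;
    }

Keeping the task-keyed view available lets callers that track tasks (rather than requests) skip the expansion entirely.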

cpp/tensorrt_llm/executor/executorImpl.cpp
Lines changed: 5 additions & 6 deletions

@@ -907,7 +907,7 @@ std::vector<IdType> Executor::Impl::enqueueRequests(common::ArrayView<Request co
     auto now = std::chrono::steady_clock::now();
     for (auto const& req : requests)
     {
-        ids.emplace_back(generateReqId());
+        ids.emplace_back(generateReqId(req));
         TLLM_LOG_DEBUG("Enqueue new request with id %d", ids.back());
 
         std::vector<IdType> childReqIds;
@@ -917,7 +917,7 @@ std::vector<IdType> Executor::Impl::enqueueRequests(common::ArrayView<Request co
         childReqIds.reserve(numChildRequests);
         for (int childId = 0; childId < numChildRequests; childId++)
         {
-            childReqIds.emplace_back(generateReqId());
+            childReqIds.emplace_back(generateLocalReqId());
             TLLM_LOG_DEBUG("Add new child request with id %d", childReqIds.back());
         }
     }
@@ -1319,7 +1319,7 @@ std::vector<RequestWithId> Executor::Impl::getLeaderNewReqWithIds(
         return reqWithIds;
     }
 
-    if (mQueuedRequests.front().id == mTerminateReqId)
+    if (mQueuedRequests.front().id == kTerminateReqId)
     {
         reqWithIds.emplace_back(std::move(mQueuedRequests.front()));
         mQueuedRequests.pop_front();
@@ -1468,7 +1468,7 @@ std::tuple<Executor::Impl::RequestList, double> Executor::Impl::fetchNewRequests
     double newActiveRequestsQueueLatencyMS{0.};
     for (auto& reqWithId : reqWithIds)
     {
-        if (reqWithId.id == mTerminateReqId)
+        if (reqWithId.id == kTerminateReqId)
         {
             mShutdown = true;
             mResponsesCv.notify_all();
@@ -2357,7 +2357,6 @@ void Executor::Impl::executionLoop()
             }
         }
     }
-
     if (!activeRequests.empty())
     {
         forwardAsync(activeRequests);
@@ -2411,7 +2410,7 @@ void Executor::Impl::enqueueTerminateRequest()
     {
         std::scoped_lock<std::mutex> lck(mQueuedReqMtx);
         Request dummyReq({1}, 1);
-        RequestWithId reqWithId{std::move(dummyReq), mTerminateReqId};
+        RequestWithId reqWithId{std::move(dummyReq), kTerminateReqId};
         mQueuedRequests.emplace_back(reqWithId);
     }
     mQueuedReqCv.notify_one();
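Beyond threading the request through generateReqId(req), the rename of mTerminateReqId to kTerminateReqId reflects that it is a compile-time constant, not member state. The value 0 acts as a sentinel: enqueueTerminateRequest() pushes a dummy request carrying it through the normal queue, and the consumer shuts down on seeing it. A minimal sketch of that pattern, independent of the executor types:

    #include <cstdint>
    #include <deque>
    #include <iostream>

    using IdType = std::uint64_t;
    static constexpr IdType kTerminateReqId = 0; // sentinel, never a real request id

    int main()
    {
        // The terminate "request" travels through the same queue as real work,
        // so the consumer needs no separate shutdown channel.
        std::deque<IdType> queue{42, 43, kTerminateReqId};
        while (!queue.empty())
        {
            IdType const id = queue.front();
            queue.pop_front();
            if (id == kTerminateReqId)
            {
                std::cout << "shutdown\n";
                break;
            }
            std::cout << "process " << id << "\n";
        }
        return 0;
    }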

cpp/tensorrt_llm/executor/executorImpl.h
Lines changed: 17 additions & 3 deletions

@@ -178,9 +178,20 @@ class Executor::Impl
     void initializeLogitsPostProcessorBatched(LogitsPostProcessorConfig const& logitsProcConfig);
 
-    IdType generateReqId()
+    IdType generateReqId(Request const& request)
     {
-        return (mLastReqId++ % UINT64_MAX);
+        // If the request has a disaggregated request id, prefer it.
+        if (request.getDisaggRequestId().has_value() && request.getDisaggRequestId().value() > kMaxLocalReqId)
+        {
+            return request.getDisaggRequestId().value();
+        }
+        // Otherwise, generate a local request id in range [1, kMaxLocalReqId).
+        return generateLocalReqId();
+    }
+
+    IdType generateLocalReqId()
+    {
+        return (mLastReqId++ % kMaxLocalReqId);
     }
 
     std::vector<RequestWithId> getLeaderNewReqWithIds(
@@ -315,7 +326,10 @@ class Executor::Impl
 
     IdType mLastReqId = 1;
 
-    static constexpr IdType mTerminateReqId = 0;
+    static constexpr IdType kTerminateReqId = 0;
+    // Request id > kMaxLocalReqId is reserved for disaggregated requests.
+    // This max ID is also in Python side.
+    static constexpr IdType kMaxLocalReqId = 1ULL << 42U;
 
     BatchingType mBatchingType;
     bool mIsSchedulerMaxUtilization;
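The id space is now partitioned at 2^42: generateLocalReqId() wraps below kMaxLocalReqId, while ids above the boundary are reserved for disaggregated requests that arrive with an id already assigned (the Python side must mirror the same constant). A small sketch of the boundary logic; isDisaggregatedId is a hypothetical helper mirroring the check inside generateReqId():

    #include <cassert>
    #include <cstdint>

    using IdType = std::uint64_t;
    static constexpr IdType kMaxLocalReqId = 1ULL << 42U;

    // Locally generated ids always wrap below the boundary...
    inline IdType nextLocalId(IdType& last)
    {
        return last++ % kMaxLocalReqId;
    }

    // ...so any id above it can be recognized as disaggregated (hypothetical helper).
    inline bool isDisaggregatedId(IdType id)
    {
        return id > kMaxLocalReqId;
    }

    int main()
    {
        IdType last = 1;
        assert(!isDisaggregatedId(nextLocalId(last)));
        assert(isDisaggregatedId((1ULL << 50U) + 7));
        return 0;
    }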

cpp/tensorrt_llm/executor/request.cpp
Lines changed: 19 additions & 8 deletions

@@ -40,7 +40,8 @@ Request::Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming,
     std::optional<SizeType32> encoderOutputLength, std::optional<Tensor> crossAttentionMask,
     SizeType32 numReturnSequences, std::optional<EagleConfig> eagleConfig, std::optional<Tensor> skipCrossAttnBlocks,
     std::optional<GuidedDecodingParams> guidedDecodingParams, std::optional<SizeType32> languageAdapterUid,
-    std::optional<MillisecondsType> allottedTimeMs, std::optional<CacheSaltIDType> cacheSaltID)
+    std::optional<MillisecondsType> allottedTimeMs, std::optional<CacheSaltIDType> cacheSaltID,
+    std::optional<IdType> disaggRequestId)
     : mImpl(std::make_unique<Impl>(std::move(inputTokenIds), maxTokens, streaming, samplingConfig, outputConfig, endId,
         padId, std::move(positionIds), std::move(badWords), std::move(stopWords), std::move(embeddingBias),
         std::move(externalDraftTokensConfig), std::move(pTuningConfig), std::move(multimodalInput),
@@ -49,7 +50,7 @@ Request::Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming,
         std::move(encoderInputTokenIds), clientId, returnAllGeneratedTokens, priority, type,
         std::move(contextPhaseParams), std::move(encoderInputFeatures), encoderOutputLength, crossAttentionMask,
         numReturnSequences, eagleConfig, skipCrossAttnBlocks, std::move(guidedDecodingParams), languageAdapterUid,
-        allottedTimeMs, cacheSaltID))
+        allottedTimeMs, cacheSaltID, disaggRequestId))
 {
 }
 
@@ -253,6 +254,11 @@ std::optional<CacheSaltIDType> Request::getCacheSaltID() const
     return mImpl->getCacheSaltID();
 }
 
+std::optional<IdType> Request::getDisaggRequestId() const
+{
+    return mImpl->getDisaggRequestId();
+}
+
 void Request::setStreaming(bool streaming)
 {
     mImpl->setStreaming(streaming);
@@ -310,12 +316,12 @@ void Request::setPromptTuningConfig(PromptTuningConfig const& pTuningConfig)
 
 void Request::setMultimodalEmbedding(Tensor const& multimodalEmbedding)
 {
-    return mImpl->setMultimodalEmbedding(multimodalEmbedding);
+    mImpl->setMultimodalEmbedding(multimodalEmbedding);
 }
 
 void Request::setMultimodalInput(MultimodalInput const& multimodalInput)
 {
-    return mImpl->setMultimodalInput(multimodalInput);
+    mImpl->setMultimodalInput(multimodalInput);
 }
 
 void Request::setMropeConfig(MropeConfig const& mRopeConfig)
@@ -400,7 +406,7 @@ void Request::setEagleConfig(std::optional<EagleConfig> const& eagleConfig)
 
 void Request::setSkipCrossAttnBlocks(Tensor skipCrossAttnBlocks)
 {
-    return mImpl->setSkipCrossAttnBlocks(skipCrossAttnBlocks);
+    mImpl->setSkipCrossAttnBlocks(skipCrossAttnBlocks);
 }
 
 void Request::setGuidedDecodingParams(GuidedDecodingParams const& guidedDecoding
@@ -410,16 +416,21 @@ void Request::setGuidedDecodingParams(GuidedDecodingParams const& guidedDecoding
 void Request::setAllottedTimeMs(MillisecondsType allottedTimeMs)
 {
-    return mImpl->setAllottedTimeMs(allottedTimeMs);
+    mImpl->setAllottedTimeMs(allottedTimeMs);
 }
 
 void Request::setLanguageAdapterUid(SizeType32 languageAdapterUid)
 {
-    return mImpl->setLanguageAdapterUid(languageAdapterUid);
+    mImpl->setLanguageAdapterUid(languageAdapterUid);
 }
 
 void Request::setCacheSaltID(CacheSaltIDType cacheSaltID)
 {
-    return mImpl->setCacheSaltID(cacheSaltID);
+    mImpl->setCacheSaltID(cacheSaltID);
+}
+
+void Request::setDisaggRequestId(IdType disaggRequestId)
+{
+    mImpl->setDisaggRequestId(disaggRequestId);
 }
 } // namespace tensorrt_llm::executor
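Alongside the disaggRequestId plumbing, the diff drops the stray `return` in front of the void mImpl setter calls; returning the result of a void expression compiles, but wrongly suggests a value is produced. A minimal sketch of the pimpl forwarding involved, under the assumption (not shown in the diff) that Impl simply stores the optional id:

    #include <cstdint>
    #include <memory>
    #include <optional>

    using IdType = std::uint64_t;

    // Toy version of the Request pimpl: public accessors forward to a hidden
    // Impl that owns the state.
    class ToyRequest
    {
    public:
        ToyRequest()
            : mImpl(std::make_unique<Impl>())
        {
        }

        void setDisaggRequestId(IdType id)
        {
            mImpl->mDisaggRequestId = id; // void call: no `return` in front
        }

        std::optional<IdType> getDisaggRequestId() const
        {
            return mImpl->mDisaggRequestId;
        }

    private:
        struct Impl
        {
            std::optional<IdType> mDisaggRequestId;
        };

        std::unique_ptr<Impl> mImpl;
    };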
