Skip to content

Commit 1dc49b2

Browse files
authored
[https://nvbugs/5322131][feat] Multi-LoRA serving with CUDA Graph (#8279)
Signed-off-by: Jiayu Chang <jiayuc@nvidia.com>
1 parent cdb9ffd commit 1dc49b2

File tree

25 files changed

+2766
-172
lines changed

25 files changed

+2766
-172
lines changed

cpp/include/tensorrt_llm/batch_manager/peftCacheManager.h

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,10 @@ class BasePeftCacheManager
5757
public:
5858
using LlmRequestPtr = std::shared_ptr<LlmRequest>;
5959
using RequestVector = std::vector<LlmRequestPtr>;
60-
using PeftTable = std::map<uint64_t, std::vector<runtime::LoraCache::TaskLayerModuleConfig>>;
60+
using PeftTable = std::unordered_map<uint64_t, std::vector<runtime::LoraCache::TaskLayerModuleConfig>>;
61+
using TaskPeftTable = std::unordered_map<uint64_t, std::vector<runtime::LoraCache::TaskLayerModuleConfig>>;
62+
using TaskIdToReqIds = std::unordered_map<uint64_t, std::vector<uint64_t>>;
63+
using EnsureBatchTaskResult = std::tuple<TaskPeftTable, TaskIdToReqIds>;
6164

6265
virtual ~BasePeftCacheManager() = default;
6366

@@ -99,6 +102,8 @@ class BasePeftCacheManager
99102
class PeftCacheManager : public BasePeftCacheManager
100103
{
101104
public:
105+
using EnsureBatchTaskResult = BasePeftCacheManager::EnsureBatchTaskResult;
106+
102107
PeftCacheManager(PeftCacheManagerConfig const& config, runtime::ModelConfig const& modelConfig,
103108
runtime::WorldConfig const& worldConfig, runtime::BufferManager const& bufferManager);
104109

@@ -109,12 +114,17 @@ class PeftCacheManager : public BasePeftCacheManager
109114
PeftTable ensureBatch(RequestVector const& contextRequests, RequestVector const& generationRequests,
110115
bool resetGpuCache = false) override;
111116

117+
EnsureBatchTaskResult ensureBatchMapTaskId(
118+
RequestVector const& contextRequests, RequestVector const& generationRequests, bool resetGpuCache = false);
119+
112120
[[nodiscard]] bool isTaskCached(uint64_t taskId) const;
113121

114122
[[nodiscard]] bool isTaskDone(uint64_t taskId) const;
115123

116124
[[nodiscard]] bool isTaskDoneDevice(uint64_t taskId) const;
117125

126+
[[nodiscard]] bool isTaskCachedDevice(uint64_t const taskId) const;
127+
118128
void resetDeviceCache() override;
119129

120130
void markRequestDone(LlmRequest const& llmReq, bool pause = false) override;
@@ -159,7 +169,7 @@ class PeftCacheManager : public BasePeftCacheManager
159169
std::unordered_map<uint64_t, std::unordered_set<uint64_t>> mTaskIdToReqIds;
160170
std::unordered_map<uint64_t, std::unordered_set<uint64_t>> mTaskIdToPausedReqIds;
161171

162-
std::tuple<std::map<uint64_t, std::future<void>>, std::map<uint64_t, std::vector<uint64_t>>> getTaskMaps(
172+
std::tuple<std::unordered_map<uint64_t, std::future<void>>, TaskIdToReqIds> getTaskMaps(
163173
RequestVector const& contextRequests, RequestVector const& generationRequests);
164174

165175
runtime::ModelConfig mModelConfig;

cpp/tensorrt_llm/batch_manager/peftCacheManager.cpp

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -373,11 +373,11 @@ void PeftCacheManager::addRequestPeft(std::shared_ptr<LlmRequest> llmRequest, bo
373373
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
374374
}
375375

376-
std::tuple<std::map<uint64_t, std::future<void>>, std::map<uint64_t, std::vector<uint64_t>>>
376+
std::tuple<std::unordered_map<uint64_t, std::future<void>>, BasePeftCacheManager::TaskIdToReqIds>
377377
PeftCacheManager::getTaskMaps(RequestVector const& contextRequests, RequestVector const& generationRequests)
378378
{
379-
std::map<uint64_t, std::vector<uint64_t>> taskIdToReqIds;
380-
std::map<uint64_t, std::future<void>> taskIdToFuture;
379+
TaskIdToReqIds taskIdToReqIds;
380+
std::unordered_map<uint64_t, std::future<void>> taskIdToFuture;
381381
std::lock_guard<std::mutex> futuresLock(mPutFuturesMutex);
382382
for (auto const& requests : {contextRequests, generationRequests})
383383
{
@@ -415,7 +415,7 @@ PeftCacheManager::getTaskMaps(RequestVector const& contextRequests, RequestVecto
415415
return {std::move(taskIdToFuture), taskIdToReqIds};
416416
}
417417

418-
PeftCacheManager::PeftTable PeftCacheManager::ensureBatch(
418+
PeftCacheManager::EnsureBatchTaskResult PeftCacheManager::ensureBatchMapTaskId(
419419
RequestVector const& contextRequests, RequestVector const& generationRequests, bool resetGpuCache)
420420
{
421421
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
@@ -426,7 +426,7 @@ PeftCacheManager::PeftTable PeftCacheManager::ensureBatch(
426426
auto [taskIdToFuture_, taskIdToReqIds] = getTaskMaps(contextRequests, generationRequests);
427427
auto taskIdToFuture = std::move(taskIdToFuture_); // captured structured bindings are a C++20 extension
428428

429-
std::map<uint64_t, std::future<std::vector<runtime::LoraCache::TaskLayerModuleConfig>>> ensureFutures;
429+
std::unordered_map<uint64_t, std::future<std::vector<runtime::LoraCache::TaskLayerModuleConfig>>> ensureFutures;
430430
for (auto& [taskId, taskFuture] : taskIdToFuture)
431431
{
432432
auto fn = [&taskIdToFuture, taskId = taskId, this]() -> std::vector<runtime::LoraCache::TaskLayerModuleConfig>
@@ -457,18 +457,31 @@ PeftCacheManager::PeftTable PeftCacheManager::ensureBatch(
457457
ensureFutures.try_emplace(taskId, std::move(f));
458458
}
459459

460-
PeftTable peftTable{};
460+
TaskPeftTable peftTable{};
461461
for (auto const& [taskId, reqIds] : taskIdToReqIds)
462462
{
463463
auto&& f = ensureFutures.at(taskId);
464464
auto const values = f.get();
465-
for (auto const& reqId : reqIds)
465+
peftTable.try_emplace(taskId, values);
466+
}
467+
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
468+
return {std::move(peftTable), std::move(taskIdToReqIds)};
469+
}
470+
471+
/// @brief Builds the per-request PEFT table for a batch.
///
/// Delegates to ensureBatchMapTaskId() to obtain the per-task LoRA
/// layer/module configs together with the task-id -> request-ids mapping,
/// then fans each task-level entry out to every request that uses that task,
/// so callers receive a table keyed by request id.
///
/// @param contextRequests context-phase requests of the batch
/// @param generationRequests generation-phase requests of the batch
/// @param resetGpuCache forwarded to ensureBatchMapTaskId()
/// @return map from request id to its task's TaskLayerModuleConfig list
PeftCacheManager::PeftTable PeftCacheManager::ensureBatch(
    RequestVector const& contextRequests, RequestVector const& generationRequests, bool resetGpuCache)
{
    auto [perTaskConfigs, taskToRequests] = ensureBatchMapTaskId(contextRequests, generationRequests, resetGpuCache);

    PeftTable perRequestConfigs{};
    for (auto const& [taskId, configs] : perTaskConfigs)
    {
        // Expand the task-level entry to every request sharing this LoRA task.
        for (auto const requestId : taskToRequests.at(taskId))
        {
            perRequestConfigs.try_emplace(requestId, configs);
        }
    }
    return perRequestConfigs;
}
473486

474487
bool PeftCacheManager::isTaskCached(uint64_t taskId) const
@@ -486,6 +499,11 @@ bool PeftCacheManager::isTaskDoneDevice(uint64_t taskId) const
486499
return mDeviceLoraCache->isDone(taskId);
487500
}
488501

502+
/// @brief Returns whether the given LoRA task is present in the device (GPU) cache.
/// Presence check only (mDeviceLoraCache->has); presumably weaker than
/// isTaskDoneDevice(), which queries isDone() — confirm semantics in LoraCache.
bool PeftCacheManager::isTaskCachedDevice(uint64_t const taskId) const
{
    return mDeviceLoraCache->has(taskId);
}
506+
489507
void PeftCacheManager::updateTaskState(uint64_t taskId, uint64_t reqId, bool terminate, bool pause)
490508
{
491509
if (!terminate)
@@ -645,3 +663,5 @@ SizeType32 NoOpPeftCacheManager::determineNumPages(std::shared_ptr<LlmRequest> l
645663
return 0;
646664
}
647665
} // namespace tensorrt_llm::batch_manager
666+
667+
// TODO: merge C++ LoRA caching status with Py Slot manager

0 commit comments

Comments (0)