@@ -373,11 +373,11 @@ void PeftCacheManager::addRequestPeft(std::shared_ptr<LlmRequest> llmRequest, bo
373373 TLLM_LOG_DEBUG (" %s stop" , __PRETTY_FUNCTION__);
374374}
375375
376- std::tuple<std::map <uint64_t , std::future<void >>, std::map< uint64_t , std::vector< uint64_t >> >
376+ std::tuple<std::unordered_map <uint64_t , std::future<void >>, BasePeftCacheManager::TaskIdToReqIds >
377377PeftCacheManager::getTaskMaps (RequestVector const & contextRequests, RequestVector const & generationRequests)
378378{
379- std::map< uint64_t , std::vector< uint64_t >> taskIdToReqIds;
380- std::map <uint64_t , std::future<void >> taskIdToFuture;
379+ TaskIdToReqIds taskIdToReqIds;
380+ std::unordered_map <uint64_t , std::future<void >> taskIdToFuture;
381381 std::lock_guard<std::mutex> futuresLock (mPutFuturesMutex );
382382 for (auto const & requests : {contextRequests, generationRequests})
383383 {
@@ -415,7 +415,7 @@ PeftCacheManager::getTaskMaps(RequestVector const& contextRequests, RequestVecto
415415 return {std::move (taskIdToFuture), taskIdToReqIds};
416416}
417417
418- PeftCacheManager::PeftTable PeftCacheManager::ensureBatch (
418+ PeftCacheManager::EnsureBatchTaskResult PeftCacheManager::ensureBatchMapTaskId (
419419 RequestVector const & contextRequests, RequestVector const & generationRequests, bool resetGpuCache)
420420{
421421 TLLM_LOG_DEBUG (" %s start" , __PRETTY_FUNCTION__);
@@ -426,7 +426,7 @@ PeftCacheManager::PeftTable PeftCacheManager::ensureBatch(
426426 auto [taskIdToFuture_, taskIdToReqIds] = getTaskMaps (contextRequests, generationRequests);
427427 auto taskIdToFuture = std::move (taskIdToFuture_); // captured structured bindings are a C++20 extension
428428
429- std::map <uint64_t , std::future<std::vector<runtime::LoraCache::TaskLayerModuleConfig>>> ensureFutures;
429+ std::unordered_map <uint64_t , std::future<std::vector<runtime::LoraCache::TaskLayerModuleConfig>>> ensureFutures;
430430 for (auto & [taskId, taskFuture] : taskIdToFuture)
431431 {
432432 auto fn = [&taskIdToFuture, taskId = taskId, this ]() -> std::vector<runtime::LoraCache::TaskLayerModuleConfig>
@@ -457,18 +457,31 @@ PeftCacheManager::PeftTable PeftCacheManager::ensureBatch(
457457 ensureFutures.try_emplace (taskId, std::move (f));
458458 }
459459
460- PeftTable peftTable{};
460+ TaskPeftTable peftTable{};
461461 for (auto const & [taskId, reqIds] : taskIdToReqIds)
462462 {
463463 auto && f = ensureFutures.at (taskId);
464464 auto const values = f.get ();
465- for (auto const & reqId : reqIds)
465+ peftTable.try_emplace (taskId, values);
466+ }
467+ TLLM_LOG_DEBUG (" %s stop" , __PRETTY_FUNCTION__);
468+ return {std::move (peftTable), std::move (taskIdToReqIds)};
469+ }
470+
471+ PeftCacheManager::PeftTable PeftCacheManager::ensureBatch (
472+ RequestVector const & contextRequests, RequestVector const & generationRequests, bool resetGpuCache)
473+ {
474+ auto [taskTable, taskIdToReqIds] = ensureBatchMapTaskId (contextRequests, generationRequests, resetGpuCache);
475+ PeftTable requestTable{};
476+ for (auto const & [taskId, values] : taskTable)
477+ {
478+ auto const & reqIds = taskIdToReqIds.at (taskId);
479+ for (auto const reqId : reqIds)
466480 {
467- peftTable .try_emplace (reqId, values);
481+ requestTable .try_emplace (reqId, values);
468482 }
469483 }
470- TLLM_LOG_DEBUG (" %s stop" , __PRETTY_FUNCTION__);
471- return peftTable;
484+ return requestTable;
472485}
473486
474487bool PeftCacheManager::isTaskCached (uint64_t taskId) const
@@ -486,6 +499,11 @@ bool PeftCacheManager::isTaskDoneDevice(uint64_t taskId) const
486499 return mDeviceLoraCache ->isDone (taskId);
487500}
488501
502+ bool PeftCacheManager::isTaskCachedDevice (uint64_t const taskId) const
503+ {
504+ return mDeviceLoraCache ->has (taskId);
505+ }
506+
489507void PeftCacheManager::updateTaskState (uint64_t taskId, uint64_t reqId, bool terminate, bool pause)
490508{
491509 if (!terminate)
@@ -645,3 +663,5 @@ SizeType32 NoOpPeftCacheManager::determineNumPages(std::shared_ptr<LlmRequest> l
645663 return 0 ;
646664}
647665} // namespace tensorrt_llm::batch_manager
666+
667+ // TODO: merge C++ LoRA caching status with Py Slot manager
0 commit comments