Skip to content

Commit 1d61b14

Browse files
committed
This is a combination of 16 commits.
Squashed commit messages (all Signed-off-by: Jiayu Chang <jiayuc@nvidia.com>):

1. Implement CUDA Graph-compatible multi LoRAs
2. Refactor CUDA Graph LoRA integration to support precomputed leading dimensions:
   - Updated `cuda_graph_grouped_gemm` and `cuda_graph_splitk_grouped_gemm` functions to accept leading dimension pointers for A, B, C, and D matrices.
   - Modified `LoraImpl` to retrieve and pass leading dimension pointers during GEMM operations.
   - Enhanced `CudaGraphLoraParams` to manage leading dimensions for each layer and module.
   - Adjusted `CudaGraphLoraManager` to initialize parameters based on actual layer configurations from the PEFT table.
   - Improved handling of layer-specific parameters to ensure compatibility with CUDA Graph operations.
   This refactor aims to optimize performance by leveraging precomputed leading dimensions, reducing overhead during GEMM execution.
3. Bug fixes
4. Move input prep to graph
5. Fix bug in adapter size
6. Pass all but `test_llama_7b_lora_config_overrides_peft_cache_config` on L40s. (Graph seems to capture code outside of the captured function?)
7. Pass all tests
8. Sync slot manager with C++
9. Update kernel alignment selection
10. Fix kernel workspace sizes
11. memcpy: use pinned memory; remove assert in slot manager eviction
12. Add fused param-fill kernel
13. Disable torch NVTX emit
14. Disable init manager without CUDA graph
15. Update CI
16. Moved files
1 parent 2695d70 commit 1d61b14

File tree

24 files changed

+3343
-137
lines changed

24 files changed

+3343
-137
lines changed

cpp/include/tensorrt_llm/batch_manager/peftCacheManager.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@ class PeftCacheManager : public BasePeftCacheManager
115115

116116
[[nodiscard]] bool isTaskDoneDevice(uint64_t taskId) const;
117117

118+
[[nodiscard]] bool isTaskCachedDevice(uint64_t const taskId) const;
119+
118120
void resetDeviceCache() override;
119121

120122
void markRequestDone(LlmRequest const& llmReq, bool pause = false) override;

cpp/tensorrt_llm/batch_manager/peftCacheManager.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -462,10 +462,8 @@ PeftCacheManager::PeftTable PeftCacheManager::ensureBatch(
462462
{
463463
auto&& f = ensureFutures.at(taskId);
464464
auto const values = f.get();
465-
for (auto const& reqId : reqIds)
466-
{
467-
peftTable.try_emplace(reqId, values);
468-
}
465+
// Map task_id to layer-module-configs instead of request_id to layer-module-configs
466+
peftTable.try_emplace(taskId, values);
469467
}
470468
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
471469
return peftTable;
@@ -486,6 +484,11 @@ bool PeftCacheManager::isTaskDoneDevice(uint64_t taskId) const
486484
return mDeviceLoraCache->isDone(taskId);
487485
}
488486

487+
bool PeftCacheManager::isTaskCachedDevice(uint64_t const taskId) const
488+
{
489+
return mDeviceLoraCache->has(taskId);
490+
}
491+
489492
void PeftCacheManager::updateTaskState(uint64_t taskId, uint64_t reqId, bool terminate, bool pause)
490493
{
491494
if (!terminate)
@@ -645,3 +648,5 @@ SizeType32 NoOpPeftCacheManager::determineNumPages(std::shared_ptr<LlmRequest> l
645648
return 0;
646649
}
647650
} // namespace tensorrt_llm::batch_manager
651+
652+
// TODO: merge C++ LoRA caching status with Py Slot manager

cpp/tensorrt_llm/kernels/cuda_graph_grouped_gemm.cu

Lines changed: 382 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/*
2+
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#pragma once
18+
19+
#include "cutlass/gemm_coord.h"
20+
#include <NvInferRuntime.h>
21+
#include <cuda_runtime.h>
22+
23+
namespace tensorrt_llm
24+
{
25+
namespace kernels
26+
{
27+
28+
/**
29+
* @brief CUDA Graph compatible wrapper for grouped GEMM operations.
30+
*
31+
* This function accepts GPU pointers directly without any workspace for parameters,
32+
* making it fully compatible with CUDA Graph capture and replay.
33+
*
34+
* @param problem_sizes_ptr GPU pointer to array of cutlass::gemm::GemmCoord
35+
* @param problem_count Number of GEMM problems
36+
* @param ptrA_gpu GPU pointer to array of A matrix pointers
37+
* @param ptrB_gpu GPU pointer to array of B matrix pointers
38+
* @param ptrC_gpu GPU pointer to array of C matrix pointers (can be nullptr)
39+
* @param ptrD_gpu GPU pointer to array of D matrix pointers
40+
* @param isLoraIn Whether this is for LoRA input transformation
41+
* @param dataType Data type of the matrices
42+
* @param minKN Minimum K*N value for kernel selection
43+
* @param stream CUDA stream
44+
*/
45+
void cuda_graph_grouped_gemm(cutlass::gemm::GemmCoord* problem_sizes_ptr, int problem_count, void** ptrA_gpu,
46+
void** ptrB_gpu, void** ptrC_gpu, void** ptrD_gpu, int64_t* lda_gpu, int64_t* ldb_gpu, int64_t* ldc_gpu,
47+
int64_t* ldd_gpu, bool isLoraIn, nvinfer1::DataType dataType, int minKN,
48+
cutlass::gemm::GemmCoord* host_max_problem_sizes_ptr, cudaStream_t stream);
49+
50+
/**
51+
* @brief CUDA Graph compatible wrapper for split-K grouped GEMM operations.
52+
*
53+
* Similar to cuda_graph_grouped_gemm but uses split-K algorithm for better
54+
* performance with certain problem sizes. No parameter workspace needed.
55+
*/
56+
void cuda_graph_splitk_grouped_gemm(cutlass::gemm::GemmCoord* problem_sizes_ptr, int problem_count, void** ptrA_gpu,
57+
void** ptrB_gpu, void** ptrC_gpu, void** ptrD_gpu, int64_t* lda_gpu, int64_t* ldb_gpu, int64_t* ldc_gpu,
58+
int64_t* ldd_gpu, bool isLoraIn, nvinfer1::DataType dataType, int splitKSlices, int minKN,
59+
cutlass::gemm::GemmCoord* host_max_problem_sizes_ptr, int64_t* splitk_offsets_gpu, cudaStream_t stream);
60+
61+
} // namespace kernels
62+
} // namespace tensorrt_llm

cpp/tensorrt_llm/kernels/lora/lora.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ int LoraImpl::run(int64_t numTokens, int64_t numReqs, void const* input, int32_t
296296
+ (loraModuleIdx * numTokens * mMaxLowRank + handled_token_num * mMaxLowRank) * typeSize));
297297

298298
auto const N2 = mOutHiddenSizes[loraModuleIdx];
299-
cutlass::gemm::GemmCoord problem_2(M, N2, N);
299+
cutlass::gemm::GemmCoord problem_2(M, N2, N); // token_num, module_output_size, lora_rank
300300
problem_sizes_2.push_back(problem_2);
301301
ptrA_2.push_back(static_cast<void*>(static_cast<char*>(lowRankWorkSpace)
302302
+ (loraModuleIdx * numTokens * mMaxLowRank + handled_token_num * mMaxLowRank) * typeSize));

0 commit comments

Comments (0)