
Commit 6f07fa8

[TRTLLM-7738][feat] Adding implementation of KVCacheManagerV2 (#10736)
Signed-off-by: Yao Yao <lowsfer@users.noreply.github.com>

KVCacheManagerV2 is a new Python-based implementation of the KV cache manager, featuring a cleaner API, better abstraction, and better code quality, without the accumulated legacy.
1 parent 9fcc93e commit 6f07fa8


54 files changed, +9442 -17 lines

.gitignore

Lines changed: 2 additions & 0 deletions
@@ -52,6 +52,8 @@ tensorrt_llm/pg_utils_bindings.*.so
 tensorrt_llm/flash_mla/
 tensorrt_llm/flash_mla_cpp_tllm.*.so
 tensorrt_llm/flash_mla_cpp_tllm.pyi
+tensorrt_llm/runtime/kv_cache_manager_v2/**/*.so
+**/*__mypyc*.so
 tensorrt_llm/scripts
 *docs/cpp_docs*
 *docs/source/_cpp_gen*

cpp/tensorrt_llm/batch_manager/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -36,6 +36,8 @@ set(SRCS
 kvCacheManager.cpp
 kvCacheEventManager.cpp
 kvCacheTransferManager.cpp
+kvCacheManagerV2Utils.cpp
+kvCacheManagerV2Utils.cu
 llmRequest.cpp
 logitsPostProcessor.cpp
 loraBuffers.cpp
cpp/tensorrt_llm/batch_manager/kvCacheManagerV2Utils.cpp (new file)

Lines changed: 163 additions & 0 deletions
@@ -0,0 +1,163 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/batch_manager/kvCacheManagerV2Utils.h"
#include "tensorrt_llm/common/logger.h"
#include <cassert>
#include <cstdio>
#include <cuda.h>
#include <fcntl.h>
#include <memory>
#include <unistd.h>
#include <vector>

namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
{

template <typename Func>
bool loopedReadWrite(Func&& func, ssize_t size) noexcept
{
    ssize_t count = 0;
    while (count < size)
    {
        ssize_t bytes = func(count);
        if (bytes <= 0)
        {
            if (errno == EINTR)
            {
                continue; // Retry on interrupt
            }
            TLLM_LOG_ERROR("Disk read/write failed: %s\n", strerror(errno));
            return false;
        }
        count += bytes;
    }
    assert(count == size);
    return true;
}

bool writeAll(int fd, ssize_t pos, void const* data, ssize_t size) noexcept
{
    return loopedReadWrite([=](ssize_t finished)
        { return pwrite(fd, static_cast<std::byte const*>(data) + finished, size - finished, pos + finished); },
        size);
}

bool readAll(int fd, ssize_t pos, void* data, ssize_t size) noexcept
{
    return loopedReadWrite([=](ssize_t finished)
        { return pread(fd, static_cast<std::byte*>(data) + finished, size - finished, pos + finished); },
        size);
}
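For illustration only, not part of this commit: writeAll and readAll above retry partial pwrite/pread calls (and EINTR) until the full range is transferred, addressing the file by absolute offset rather than the file cursor. A minimal round-trip sketch, assuming the two functions are visible (e.g. from inside the same namespace); the helper name and file path below are hypothetical:

// Hypothetical usage sketch; roundTripExample and the path are illustrative only.
#include <cstddef>  // std::byte
#include <fcntl.h>  // open
#include <unistd.h> // close
#include <vector>

bool roundTripExample()
{
    int const fd = open("/tmp/kv_block.bin", O_RDWR | O_CREAT, 0600);
    if (fd < 0)
    {
        return false;
    }
    std::vector<std::byte> out(4096, std::byte{0x42});
    std::vector<std::byte> in(out.size());
    // Both calls address the file by absolute offset, so no lseek is needed.
    bool const ok = writeAll(fd, /*pos=*/0, out.data(), static_cast<ssize_t>(out.size()))
        && readAll(fd, /*pos=*/0, in.data(), static_cast<ssize_t>(in.size()));
    close(fd);
    return ok && in == out;
}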

template <typename DstAddr, typename SrcAddr>
struct UserData
{
    std::vector<Task<DstAddr, SrcAddr>> tasks;
    ssize_t numBytes;
};

CUDA_CB void hostFnDiskToDiskCopy(void* userData) noexcept
{
    // @TODO: enable multi-threading with a thread pool
    using Data = UserData<DiskAddress, DiskAddress>;
    auto const data = std::unique_ptr<Data>(static_cast<Data*>(userData));
    std::vector<std::byte> buffer(data->numBytes);
    bool success = true;
    for (auto const& t : data->tasks)
    {
        success = success && readAll(t.src.fd, t.src.pos, buffer.data(), data->numBytes);
        success = success && writeAll(t.dst.fd, t.dst.pos, buffer.data(), data->numBytes);
    }
    if (!success)
    {
        TLLM_LOG_ERROR("[kvCacheManagerV2Utils] hostFnDiskToDiskCopy failed.\n");
    }
}

CUDA_CB void hostFnDiskToHostCopy(void* userData) noexcept
{
    // @TODO: enable multi-threading with a thread pool
    using Data = UserData<MemAddress, DiskAddress>;
    auto const data = std::unique_ptr<Data>(static_cast<Data*>(userData));
    bool success = true;
    for (auto const& t : data->tasks)
    {
        success = success && readAll(t.src.fd, t.src.pos, reinterpret_cast<void*>(t.dst), data->numBytes);
    }
    if (!success)
    {
        TLLM_LOG_ERROR("[kvCacheManagerV2Utils] hostFnDiskToHostCopy failed.\n");
    }
}

CUDA_CB void hostFnHostToDiskCopy(void* userData) noexcept
{
    // @TODO: enable multi-threading with a thread pool
    using Data = UserData<DiskAddress, MemAddress>;
    auto const data = std::unique_ptr<Data>(static_cast<Data*>(userData));
    bool success = true;
    for (auto const& t : data->tasks)
    {
        success = success && writeAll(t.dst.fd, t.dst.pos, reinterpret_cast<void const*>(t.src), data->numBytes);
    }
    if (!success)
    {
        TLLM_LOG_ERROR("[kvCacheManagerV2Utils] hostFnHostToDiskCopy failed.\n");
    }
}

CUDA_CB void hostFnHostToHostCopy(void* userData) noexcept
{
    // @TODO: enable multi-threading with a thread pool
    using Data = UserData<MemAddress, MemAddress>;
    auto const data = std::unique_ptr<Data>(static_cast<Data*>(userData));
    for (auto const& t : data->tasks)
    {
        memcpy(reinterpret_cast<void*>(t.dst), reinterpret_cast<void const*>(t.src), data->numBytes);
    }
}

CUresult copyDiskToDisk(std::vector<Task<DiskAddress, DiskAddress>> tasks, ssize_t numBytes, CUstream stream) noexcept
{
    using Data = UserData<DiskAddress, DiskAddress>;
    auto data = std::make_unique<Data>(Data{std::move(tasks), numBytes});
    return cuLaunchHostFunc(stream, hostFnDiskToDiskCopy, data.release());
}

CUresult copyDiskToHost(std::vector<Task<MemAddress, DiskAddress>> tasks, ssize_t numBytes, CUstream stream) noexcept
{
    using Data = UserData<MemAddress, DiskAddress>;
    auto data = std::make_unique<Data>(Data{std::move(tasks), numBytes});
    return cuLaunchHostFunc(stream, hostFnDiskToHostCopy, data.release());
}

CUresult copyHostToDisk(std::vector<Task<DiskAddress, MemAddress>> tasks, ssize_t numBytes, CUstream stream) noexcept
{
    using Data = UserData<DiskAddress, MemAddress>;
    auto data = std::make_unique<Data>(Data{std::move(tasks), numBytes});
    return cuLaunchHostFunc(stream, hostFnHostToDiskCopy, data.release());
}

CUresult copyHostToHost(std::vector<Task<MemAddress, MemAddress>> tasks, ssize_t numBytes, CUstream stream) noexcept
{
    using Data = UserData<MemAddress, MemAddress>;
    auto data = std::make_unique<Data>(Data{std::move(tasks), numBytes});
    return cuLaunchHostFunc(stream, hostFnHostToHostCopy, data.release());
}

} // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
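A note on the design, with a standalone sketch: the copyDiskTo*/copyHostTo* entry points above do not perform I/O directly; they enqueue a host callback via cuLaunchHostFunc, so disk and host-memory copies are ordered on the CUDA stream like any other asynchronous operation. Ownership of the task list is handed to the callback with release() and re-wrapped in a unique_ptr there, so it is freed exactly once. The sketch below shows only that pattern; Payload, hostCallback and enqueueCallback are hypothetical names, not part of the commit's API.

// Standalone sketch of the cuLaunchHostFunc ownership pattern used above.
#include <cstdio>
#include <cuda.h>
#include <memory>

struct Payload
{
    int value;
};

CUDA_CB void hostCallback(void* userData)
{
    // Re-wrap in unique_ptr: the callback takes ownership and frees the payload on return.
    auto const data = std::unique_ptr<Payload>(static_cast<Payload*>(userData));
    std::printf("host callback ran with value %d\n", data->value);
}

CUresult enqueueCallback(CUstream stream, int value)
{
    auto data = std::make_unique<Payload>(Payload{value});
    // release() hands ownership to the callback; a production version would
    // re-own and free the payload if the launch itself fails.
    return cuLaunchHostFunc(stream, hostCallback, data.release());
}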
cpp/tensorrt_llm/batch_manager/kvCacheManagerV2Utils.cu (new file)

Lines changed: 182 additions & 0 deletions
@@ -0,0 +1,182 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "kvCacheManagerV2Utils.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include <algorithm>
#include <array>
#include <cassert>
#include <cuda_runtime.h>

namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
{
using Grain = uint4;
constexpr uint32_t ctaSize = 128;
constexpr uint32_t nbBufs = 4;
constexpr uint32_t grainBytes = sizeof(Grain);

using MMTask = Task<MemAddress, MemAddress>;

__device__ __host__ inline uint32_t divUp(uint32_t a, uint32_t b)
{
    return (a + b - 1) / b;
}

template <uint32_t N>
__global__ void batchedCopy(std::array<MMTask, N> const __grid_constant__ tasks, uint32_t nbBytes)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
    asm volatile("griddepcontrol.launch_dependents;\n");
#endif
    assert(nbBytes % sizeof(Grain) == 0);
    __shared__ Grain data[nbBufs][ctaSize];

    uint32_t const nbTasks = gridDim.y;
    assert(nbTasks <= N);
    auto const& task = tasks[blockIdx.y];
    uint32_t const nbSplits = gridDim.x;
    uint32_t const idxSplit = blockIdx.x;
    uint32_t const tid = threadIdx.x;

    constexpr uint32_t bytesPerIter = grainBytes * ctaSize;

    uint32_t const totalIters = divUp(nbBytes, bytesPerIter);
    uint32_t const maxItersPerCta = divUp(totalIters, nbSplits);
    uint32_t const idxGrainBeg = ctaSize * maxItersPerCta * idxSplit + tid;
    uint32_t const idxGrainEnd = std::min(idxGrainBeg + ctaSize * maxItersPerCta, nbBytes / grainBytes);

#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
    asm volatile("griddepcontrol.wait;\n");
#endif
    for (uint32_t i = 0; i < maxItersPerCta + nbBufs; i++)
    {
        uint32_t const idxBuf = i % nbBufs;
        if (i >= nbBufs)
        {
            uint32_t const stIter = i - nbBufs;
            assert(idxBuf == (stIter % nbBufs));
            Grain const& src = data[idxBuf][tid];
            uint32_t const idxGrain = idxGrainBeg + ctaSize * stIter;
            Grain& dst = reinterpret_cast<Grain*>(task.dst)[idxGrain];
            asm volatile("cp.async.wait_group %0;\n" ::"n"(nbBufs - 1) : "memory");
            if (idxGrain < idxGrainEnd)
            {
                dst = src;
            }
        }
        uint32_t const ldIter = i;
        Grain* const dst = &data[idxBuf][tid];
        uint32_t const idxGrain = idxGrainBeg + ctaSize * ldIter;
        Grain const* const src = &reinterpret_cast<Grain const*>(task.src)[idxGrain];
        if (idxGrain < idxGrainEnd)
        {
            uint32_t const size = grainBytes;
            asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"l"(__cvta_generic_to_shared(dst)),
                "l"(src), "n"(grainBytes), "r"(size)
                : "memory");
        }
        asm volatile("cp.async.commit_group;\n" : : : "memory");
    }
}
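Reviewer note, not part of the diff: the loop above is a software pipeline nbBufs deep. Iteration i issues a cp.async load into shared buffer i % nbBufs and, once i >= nbBufs, drains the load issued nbBufs iterations earlier from that same buffer. An illustrative schedule with an assumed maxItersPerCta of 6 (a hypothetical value):

// Illustrative schedule with nbBufs = 4 and an assumed maxItersPerCta = 6; the loop runs
// maxItersPerCta + nbBufs = 10 iterations.
//
//   i | load issued (if idxGrain < idxGrainEnd) | store drained (stIter = i - nbBufs)
//   0 | load 0 -> buf 0                         | -
//   1 | load 1 -> buf 1                         | -
//   2 | load 2 -> buf 2                         | -
//   3 | load 3 -> buf 3                         | -
//   4 | load 4 -> buf 0                         | store 0 from buf 0
//   5 | load 5 -> buf 1                         | store 1 from buf 1
//   6 | (out of range, skipped)                 | store 2 from buf 2
//   ...
//   9 | (out of range, skipped)                 | store 5 from buf 1
//
// cp.async.wait_group with nbBufs - 1 = 3 before each store allows at most 3 commit
// groups to remain in flight, so the group that filled the buffer being stored has
// already completed.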

template <uint32_t N>
CUresult launchBatchedCopyImpl(
    bool lowBandwidth, MMTask const* tasks, uint32_t nbTasks, uint32_t nbBytes, cudaStream_t stream)
{
    TLLM_CHECK(nbTasks <= N);
    TLLM_CHECK_WITH_INFO(
        nbBytes % sizeof(Grain) == 0, "Not implemented case: nbBytes = %d must be a multiple of 16.", nbBytes);
    std::array<MMTask, N> const* pTasks;
    std::array<MMTask, N> tmp;
    if (nbTasks < N)
    {
        std::copy_n(tasks, nbTasks, tmp.begin());
        pTasks = &tmp;
    }
    else
    {
        pTasks = reinterpret_cast<std::array<MMTask, N> const*>(tasks);
    }
    uint32_t const nbSplits = lowBandwidth ? 1 : divUp(nbBytes, grainBytes * ctaSize * 2);
    void* args[] = {(void*) pTasks, (void*) &nbBytes};
    static CUkernel const kernel = [] -> CUkernel
    {
        cudaKernel_t kernel = nullptr;
        TLLM_CUDA_CHECK(cudaGetKernel(&kernel, reinterpret_cast<void const*>(&batchedCopy<N>)));
        return kernel;
    }();
    return common::CUDADriverWrapper::getInstance()->cuLaunchKernel(reinterpret_cast<CUfunction>(kernel), nbSplits,
        nbTasks, 1, // gridDimX, gridDimY, gridDimZ
        ctaSize, 1, 1, // blockDimX, blockDimY, blockDimZ
        0, // sharedMemBytes
        stream, args, nullptr);
}

// When bandwidth is low, e.g. when host memory is involved, we avoid splitting as fewer CTAs should be enough to
// saturate the bandwidth.
CUresult launchBatchedCopy(bool lowBandwidth, std::vector<MMTask> const& tasks, uint32_t nbBytes, cudaStream_t stream)
{
    constexpr uint32_t maxN = 256;
    uint32_t const nbWholeBatches = tasks.size() / maxN;
    for (uint32_t i = 0; i < nbWholeBatches; i++)
    {
        CUresult const err = launchBatchedCopyImpl<maxN>(lowBandwidth, tasks.data() + maxN * i, maxN, nbBytes, stream);
        if (err != CUDA_SUCCESS)
        {
            return err;
        }
    }
    {
        auto const* const pTasks = tasks.data() + maxN * nbWholeBatches;
        auto const batchSize = tasks.size() % maxN;
        if (batchSize == 0)
        {
            return CUDA_SUCCESS;
        }
        if (batchSize > maxN / 2)
        {
            return launchBatchedCopyImpl<maxN>(lowBandwidth, pTasks, batchSize, nbBytes, stream);
        }
        if (batchSize > maxN / 4)
        {
            return launchBatchedCopyImpl<maxN / 2>(lowBandwidth, pTasks, batchSize, nbBytes, stream);
        }
        if (batchSize > maxN / 8)
        {
            return launchBatchedCopyImpl<maxN / 4>(lowBandwidth, pTasks, batchSize, nbBytes, stream);
        }
        return launchBatchedCopyImpl<maxN / 8>(lowBandwidth, pTasks, batchSize, nbBytes, stream);
    }
}

CUresult copyHostToDevice(std::vector<MMTask> const& tasks, ssize_t numBytes, CUstream stream) noexcept
{
    return launchBatchedCopy(true, tasks, numBytes, stream);
}

CUresult copyDeviceToHost(std::vector<MMTask> const& tasks, ssize_t numBytes, CUstream stream) noexcept
{
    return launchBatchedCopy(true, tasks, numBytes, stream);
}

CUresult copyDeviceToDevice(std::vector<MMTask> const& tasks, ssize_t numBytes, CUstream stream) noexcept
{
    return launchBatchedCopy(false, tasks, numBytes, stream);
}

} // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
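Reviewer note, not part of the diff: the split heuristic above sizes gridDim.x so that, on the high-bandwidth (device-to-device) path, each CTA covers roughly two pipeline iterations, while host-involved copies use a single CTA per task. A small worked example of that arithmetic, under an assumed 64 KiB per-task copy size (the size is an assumption, not from the commit):

// Illustrative arithmetic only; names carry an "Example" suffix to avoid clashing with the file above.
#include <cstdint>

constexpr uint32_t divUpExample(uint32_t a, uint32_t b)
{
    return (a + b - 1) / b;
}

constexpr uint32_t ctaSizeExample = 128;   // threads per CTA, as above
constexpr uint32_t grainBytesExample = 16; // sizeof(uint4), as above
// One CTA moves ctaSize * grainBytes = 2048 bytes per pipeline iteration.
constexpr uint32_t nbBytesExample = 64 * 1024; // hypothetical per-task copy size

// Device-to-device path (lowBandwidth == false):
constexpr uint32_t nbSplitsHigh = divUpExample(nbBytesExample, grainBytesExample * ctaSizeExample * 2);
static_assert(nbSplitsHigh == 16, "16 CTAs per task, ~2 iterations (4 KiB) each");

// Host-involved paths (copyHostToDevice / copyDeviceToHost) pass lowBandwidth == true,
// so nbSplits == 1 and the launch grid is (1, nbTasks, 1).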
