|
| 1 | +/* |
| 2 | + * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 3 | + * SPDX-License-Identifier: Apache-2.0 |
| 4 | + * |
| 5 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 | + * you may not use this file except in compliance with the License. |
| 7 | + * You may obtain a copy of the License at |
| 8 | + * |
| 9 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | + * |
| 11 | + * Unless required by applicable law or agreed to in writing, software |
| 12 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | + * See the License for the specific language governing permissions and |
| 15 | + * limitations under the License. |
| 16 | + */ |
| 17 | + |
| 18 | +#include "kvCacheManagerV2Utils.h" |
| 19 | +#include "tensorrt_llm/common/assert.h" |
| 20 | +#include "tensorrt_llm/common/cudaUtils.h" |
| 21 | +#include <algorithm> |
| 22 | +#include <array> |
| 23 | +#include <cassert> |
| 24 | +#include <cuda_runtime.h> |
| 25 | + |
| 26 | +namespace tensorrt_llm::batch_manager::kv_cache_manager_v2 |
| 27 | +{ |
| 28 | +using Grain = uint4; |
| 29 | +constexpr uint32_t ctaSize = 128; |
| 30 | +constexpr uint32_t nbBufs = 4; |
| 31 | +constexpr uint32_t grainBytes = sizeof(Grain); |
| 32 | + |
| 33 | +using MMTask = Task<MemAddress, MemAddress>; |
| 34 | + |
| 35 | +__device__ __host__ inline uint32_t divUp(uint32_t a, uint32_t b) |
| 36 | +{ |
| 37 | + return (a + b - 1) / b; |
| 38 | +} |
| 39 | + |
| 40 | +template <uint32_t N> |
| 41 | +__global__ void batchedCopy(std::array<MMTask, N> const __grid_constant__ tasks, uint32_t nbBytes) |
| 42 | +{ |
| 43 | +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) |
| 44 | + asm volatile("griddepcontrol.launch_dependents;\n"); |
| 45 | +#endif |
| 46 | + assert(nbBytes % sizeof(Grain) == 0); |
| 47 | + __shared__ Grain data[nbBufs][ctaSize]; |
| 48 | + |
| 49 | + uint32_t const nbTasks = gridDim.y; |
| 50 | + assert(nbTasks <= N); |
| 51 | + auto const& task = tasks[blockIdx.y]; |
| 52 | + uint32_t const nbSplits = gridDim.x; |
| 53 | + uint32_t const idxSplit = blockIdx.x; |
| 54 | + uint32_t const tid = threadIdx.x; |
| 55 | + |
| 56 | + constexpr uint32_t bytesPerIter = grainBytes * ctaSize; |
| 57 | + |
| 58 | + uint32_t const totalIters = divUp(nbBytes, bytesPerIter); |
| 59 | + uint32_t const maxItersPerCta = divUp(totalIters, nbSplits); |
| 60 | + uint32_t const idxGrainBeg = ctaSize * maxItersPerCta * idxSplit + tid; |
| 61 | + uint32_t const idxGrainEnd = std::min(idxGrainBeg + ctaSize * maxItersPerCta, nbBytes / grainBytes); |
| 62 | + |
| 63 | +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) |
| 64 | + asm volatile("griddepcontrol.wait;\n"); |
| 65 | +#endif |
| 66 | + for (uint32_t i = 0; i < maxItersPerCta + nbBufs; i++) |
| 67 | + { |
| 68 | + uint32_t const idxBuf = i % nbBufs; |
| 69 | + if (i >= nbBufs) |
| 70 | + { |
| 71 | + uint32_t const stIter = i - nbBufs; |
| 72 | + assert(idxBuf == (stIter % nbBufs)); |
| 73 | + Grain const& src = data[idxBuf][tid]; |
| 74 | + uint32_t const idxGrain = idxGrainBeg + ctaSize * stIter; |
| 75 | + Grain& dst = reinterpret_cast<Grain*>(task.dst)[idxGrain]; |
| 76 | + asm volatile("cp.async.wait_group %0;\n" ::"n"(nbBufs - 1) : "memory"); |
| 77 | + if (idxGrain < idxGrainEnd) |
| 78 | + { |
| 79 | + dst = src; |
| 80 | + } |
| 81 | + } |
| 82 | + uint32_t const ldIter = i; |
| 83 | + Grain* const dst = &data[idxBuf][tid]; |
| 84 | + uint32_t const idxGrain = idxGrainBeg + ctaSize * ldIter; |
| 85 | + Grain const* const src = &reinterpret_cast<Grain const*>(task.src)[idxGrain]; |
| 86 | + if (idxGrain < idxGrainEnd) |
| 87 | + { |
| 88 | + uint32_t const size = grainBytes; |
| 89 | + asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"l"(__cvta_generic_to_shared(dst)), |
| 90 | + "l"(src), "n"(grainBytes), "r"(size) |
| 91 | + : "memory"); |
| 92 | + } |
| 93 | + asm volatile("cp.async.commit_group;\n" : : : "memory"); |
| 94 | + } |
| 95 | +} |
| 96 | + |
| 97 | +template <uint32_t N> |
| 98 | +CUresult launchBatchedCopyImpl( |
| 99 | + bool lowBandwidth, MMTask const* tasks, uint32_t nbTasks, uint32_t nbBytes, cudaStream_t stream) |
| 100 | +{ |
| 101 | + TLLM_CHECK(nbTasks <= N); |
| 102 | + TLLM_CHECK_WITH_INFO( |
| 103 | + nbBytes % sizeof(Grain) == 0, "Not implemented case: nbBytes = %d must be a multiple of 16.", nbBytes); |
| 104 | + std::array<MMTask, N> const* pTasks; |
| 105 | + std::array<MMTask, N> tmp; |
| 106 | + if (nbTasks < N) |
| 107 | + { |
| 108 | + std::copy_n(tasks, nbTasks, tmp.begin()); |
| 109 | + pTasks = &tmp; |
| 110 | + } |
| 111 | + else |
| 112 | + { |
| 113 | + pTasks = reinterpret_cast<std::array<MMTask, N> const*>(tasks); |
| 114 | + } |
| 115 | + uint32_t const nbSplits = lowBandwidth ? 1 : divUp(nbBytes, grainBytes * ctaSize * 2); |
| 116 | + void* args[] = {(void*) pTasks, (void*) &nbBytes}; |
| 117 | + static CUkernel const kernel = [] -> CUkernel |
| 118 | + { |
| 119 | + cudaKernel_t kernel = nullptr; |
| 120 | + TLLM_CUDA_CHECK(cudaGetKernel(&kernel, reinterpret_cast<void const*>(&batchedCopy<N>))); |
| 121 | + return kernel; |
| 122 | + }(); |
| 123 | + return common::CUDADriverWrapper::getInstance()->cuLaunchKernel(reinterpret_cast<CUfunction>(kernel), nbSplits, |
| 124 | + nbTasks, 1, // gridDimX, gridDimY, gridDimZ |
| 125 | + ctaSize, 1, 1, // blockDimX, blockDimY, blockDimZ |
| 126 | + 0, // sharedMemBytes |
| 127 | + stream, args, nullptr); |
| 128 | +} |
| 129 | + |
| 130 | +// When bandwidth is low, e.g. when host memory is involved, we avoid splitting as fewer CTAs should be enough to |
| 131 | +// saturate the bandwidth. |
| 132 | +CUresult launchBatchedCopy(bool lowBandwidth, std::vector<MMTask> const& tasks, uint32_t nbBytes, cudaStream_t stream) |
| 133 | +{ |
| 134 | + constexpr uint32_t maxN = 256; |
| 135 | + uint32_t const nbWholeBatches = tasks.size() / maxN; |
| 136 | + for (uint32_t i = 0; i < nbWholeBatches; i++) |
| 137 | + { |
| 138 | + CUresult const err = launchBatchedCopyImpl<maxN>(lowBandwidth, tasks.data() + maxN * i, maxN, nbBytes, stream); |
| 139 | + if (err != CUDA_SUCCESS) |
| 140 | + { |
| 141 | + return err; |
| 142 | + } |
| 143 | + } |
| 144 | + { |
| 145 | + auto const* const pTasks = tasks.data() + maxN * nbWholeBatches; |
| 146 | + auto const batchSize = tasks.size() % maxN; |
| 147 | + if (batchSize == 0) |
| 148 | + { |
| 149 | + return CUDA_SUCCESS; |
| 150 | + } |
| 151 | + if (batchSize > maxN / 2) |
| 152 | + { |
| 153 | + return launchBatchedCopyImpl<maxN>(lowBandwidth, pTasks, batchSize, nbBytes, stream); |
| 154 | + } |
| 155 | + if (batchSize > maxN / 4) |
| 156 | + { |
| 157 | + return launchBatchedCopyImpl<maxN / 2>(lowBandwidth, pTasks, batchSize, nbBytes, stream); |
| 158 | + } |
| 159 | + if (batchSize > maxN / 8) |
| 160 | + { |
| 161 | + return launchBatchedCopyImpl<maxN / 4>(lowBandwidth, pTasks, batchSize, nbBytes, stream); |
| 162 | + } |
| 163 | + return launchBatchedCopyImpl<maxN / 8>(lowBandwidth, pTasks, batchSize, nbBytes, stream); |
| 164 | + } |
| 165 | +} |
| 166 | + |
| 167 | +CUresult copyHostToDevice(std::vector<MMTask> const& tasks, ssize_t numBytes, CUstream stream) noexcept |
| 168 | +{ |
| 169 | + return launchBatchedCopy(true, tasks, numBytes, stream); |
| 170 | +} |
| 171 | + |
| 172 | +CUresult copyDeviceToHost(std::vector<MMTask> const& tasks, ssize_t numBytes, CUstream stream) noexcept |
| 173 | +{ |
| 174 | + return launchBatchedCopy(true, tasks, numBytes, stream); |
| 175 | +} |
| 176 | + |
| 177 | +CUresult copyDeviceToDevice(std::vector<MMTask> const& tasks, ssize_t numBytes, CUstream stream) noexcept |
| 178 | +{ |
| 179 | + return launchBatchedCopy(false, tasks, numBytes, stream); |
| 180 | +} |
| 181 | + |
| 182 | +} // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2 |
0 commit comments