|
| 1 | +/* |
| 2 | + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
| 16 | + |
| 17 | +#include "cutlass_extensions/gemm_configs.h" |
| 18 | + |
| 19 | +#include "tensorrt_llm/common/cudaUtils.h" |
| 20 | +#include "tensorrt_llm/kernels/cutlass_kernels/include/allreduce_gemm_runner.h" |
| 21 | +#include "tensorrt_llm/runtime/ipcNvlsMemory.h" |
| 22 | +#include "tensorrt_llm/thop/thUtils.h" |
| 23 | + |
| 24 | +#include <ATen/cuda/EmptyTensor.h> |
| 25 | + |
| 26 | +#include <cstddef> |
| 27 | +#include <cuda_fp16.h> |
| 28 | + |
| 29 | +#include <cstdint> |
| 30 | +#include <functional> |
| 31 | +#include <type_traits> |
| 32 | +#include <vector> |
| 33 | + |
| 34 | +using tensorrt_llm::kernels::opened_cutlass_kernels::GemmAllReduceImplRunner; |
| 35 | +using tensorrt_llm::kernels::opened_cutlass_kernels::GemmAllReduceImplInterface; |
| 36 | +using tensorrt_llm::kernels::opened_cutlass_kernels::GemmTypes; |
| 37 | +using tensorrt_llm::kernels::opened_cutlass_kernels::PersistentWorkspaceInterface; |
| 38 | + |
| 39 | +namespace |
| 40 | +{ |
| 41 | +struct AllocationKey |
| 42 | +{ |
| 43 | + int64_t device_index; |
| 44 | + std::set<int> group; |
| 45 | + |
| 46 | + bool operator==(AllocationKey const& other) const |
| 47 | + { |
| 48 | + return device_index == other.device_index && group == other.group; |
| 49 | + } |
| 50 | + |
| 51 | + std::string toString() const |
| 52 | + { |
| 53 | + std::stringstream ss; |
| 54 | + ss << "AllocationKey(device: " << device_index << ", group: ["; |
| 55 | + for (int rank : group) |
| 56 | + { |
| 57 | + ss << rank << ", "; |
| 58 | + } |
| 59 | + ss << "])"; |
| 60 | + return ss.str(); |
| 61 | + } |
| 62 | +}; |
| 63 | + |
| 64 | +struct AllocationKeyHash |
| 65 | +{ |
| 66 | + size_t operator()(AllocationKey const& key) const |
| 67 | + { |
| 68 | + size_t seed = 0; |
| 69 | + |
| 70 | + // Hash the device index |
| 71 | + hash_combine(seed, key.device_index); |
| 72 | + |
| 73 | + // Hash the set elements |
| 74 | + for (auto const& elem : key.group) |
| 75 | + { |
| 76 | + hash_combine(seed, elem); |
| 77 | + } |
| 78 | + |
| 79 | + return seed; |
| 80 | + } |
| 81 | + |
| 82 | +private: |
| 83 | + template <typename T> |
| 84 | + static void hash_combine(size_t& seed, T const& val) |
| 85 | + { |
| 86 | + seed ^= std::hash<T>()(val) + 0x9e3779b9 + (seed << 6) + (seed >> 2); |
| 87 | + } |
| 88 | +}; |
| 89 | + |
| 90 | +class IpcNvlsHandleWrapper |
| 91 | +{ |
| 92 | +public: |
| 93 | + IpcNvlsHandleWrapper(size_t size, std::set<int> groups) |
| 94 | + : mSize(size) |
| 95 | + { |
| 96 | + mHandle = tensorrt_llm::runtime::ipcNvlsAllocate(size, groups); |
| 97 | + } |
| 98 | + |
| 99 | + tensorrt_llm::runtime::IpcNvlsHandle* getHandle() const |
| 100 | + { |
| 101 | + return mHandle; |
| 102 | + } |
| 103 | + |
| 104 | + size_t getSize() const |
| 105 | + { |
| 106 | + return mSize; |
| 107 | + } |
| 108 | + |
| 109 | + ~IpcNvlsHandleWrapper() |
| 110 | + { |
| 111 | + tensorrt_llm::runtime::ipcNvlsFree(mHandle); |
| 112 | + } |
| 113 | + |
| 114 | +private: |
| 115 | + size_t mSize; |
| 116 | + tensorrt_llm::runtime::IpcNvlsHandle* mHandle; |
| 117 | +}; |
| 118 | + |
| 119 | +std::once_flag init_flag; |
| 120 | + |
| 121 | +size_t getPreferredWorkspaceSize() |
| 122 | +{ |
| 123 | + // 128MB |
| 124 | + static size_t preferredWorkspaceSize = 134217728; |
| 125 | + std::call_once(init_flag, |
| 126 | + [&]() |
| 127 | + { |
| 128 | + char const* envWorkspaceSize = std::getenv("TRTLLM_GEMM_ALLREDUCE_WORKSPACE_SIZE"); |
| 129 | + size_t workspaceSize = 0; |
| 130 | + if (envWorkspaceSize != nullptr) |
| 131 | + { |
| 132 | + workspaceSize = std::atoi(envWorkspaceSize); |
| 133 | + } |
| 134 | + preferredWorkspaceSize = std::max(preferredWorkspaceSize, workspaceSize); |
| 135 | + }); |
| 136 | + return preferredWorkspaceSize; |
| 137 | +} |
| 138 | + |
| 139 | +class GemmAllreduceNvlsMemoryManager |
| 140 | +{ |
| 141 | +public: |
| 142 | + GemmAllreduceNvlsMemoryManager() |
| 143 | + { |
| 144 | + TLLM_LOG_INFO("GemmAllreduceNvlsMemoryManager constructor"); |
| 145 | + } |
| 146 | + |
| 147 | + ~GemmAllreduceNvlsMemoryManager() |
| 148 | + { |
| 149 | + TLLM_LOG_INFO("GemmAllreduceNvlsMemoryManager destructor"); |
| 150 | + } |
| 151 | + |
| 152 | + std::pair<PersistentWorkspaceInterface*, tensorrt_llm::runtime::IpcNvlsHandle*> getWorkspace( |
| 153 | + GemmAllReduceImplInterface* runner, GemmAllReduceImplInterface::ProblemArgs const& problem, |
| 154 | + AllocationKey const& key) |
| 155 | + { |
| 156 | + int M = std::get<0>(problem.problem_size); |
| 157 | + int N = std::get<1>(problem.problem_size); |
| 158 | + size_t requiredSize = M * N * 2; |
| 159 | + size_t preferredWorkspaceSize = getPreferredWorkspaceSize(); |
| 160 | + if (requiredSize > preferredWorkspaceSize) |
| 161 | + { |
| 162 | + std::stringstream ss; |
| 163 | + ss << "Please set TRTLLM_GEMM_ALLREDUCE_WORKSPACE_SIZE to at least " << requiredSize << " bytes"; |
| 164 | + C10_THROW_ERROR(ErrorAlwaysShowCppStacktrace, ss.str().c_str()); |
| 165 | + } |
| 166 | + |
| 167 | + auto handle = mHandles[key]; |
| 168 | + if (handle == nullptr) |
| 169 | + { |
| 170 | + TLLM_LOG_INFO("Creating allreduce workspace for %s", key.toString().c_str()); |
| 171 | + handle = std::make_shared<IpcNvlsHandleWrapper>(preferredWorkspaceSize, key.group); |
| 172 | + GemmAllReduceImplInterface::ProblemArgs tmpArgs; |
| 173 | + int maxN = 16384; |
| 174 | + int maxM = preferredWorkspaceSize / (maxN * 2); |
| 175 | + tmpArgs.argProblemShape(maxM, maxN, 512, 1) |
| 176 | + .argRanks(problem.rank, problem.ranks) |
| 177 | + .argLaunchConfig(runner->getSupportedLaunchConfigs()[0]); |
| 178 | + auto workspace = runner->getPersistentWorkspace(tmpArgs); |
| 179 | + workspace->allocate(); |
| 180 | + mWorkspaces[key] = workspace; |
| 181 | + mHandles[key] = handle; |
| 182 | + } |
| 183 | + return std::make_pair(mWorkspaces[key].get(), mHandles[key]->getHandle()); |
| 184 | + } |
| 185 | + |
| 186 | +private: |
| 187 | + std::unordered_map<AllocationKey, std::shared_ptr<PersistentWorkspaceInterface>, AllocationKeyHash> mWorkspaces; |
| 188 | + std::unordered_map<AllocationKey, std::shared_ptr<IpcNvlsHandleWrapper>, AllocationKeyHash> mHandles; |
| 189 | +}; |
| 190 | + |
| 191 | +GemmAllreduceNvlsMemoryManager* getGemmAllreduceNvlsMemoryManager() |
| 192 | +{ |
| 193 | + static GemmAllreduceNvlsMemoryManager gNvlsMemoryManager; |
| 194 | + return &gNvlsMemoryManager; |
| 195 | +} |
| 196 | + |
| 197 | +at::Tensor runGemmImpl(GemmAllReduceImplInterface* runner, GemmAllReduceImplInterface::ProblemArgs& problem, |
| 198 | + at::ScalarType outputDtype, c10::cuda::CUDAStream stream) |
| 199 | +{ |
| 200 | + AllocationKey key{stream.device_index(), problem.ranks}; |
| 201 | + auto [workspace, handle] = getGemmAllreduceNvlsMemoryManager()->getWorkspace(runner, problem, key); |
| 202 | + problem.argD((void*) handle->uc_ptr, (void*) handle->mc_ptr, (void**) handle->ipc_uc_ptrs.data()); |
| 203 | + problem.argWorkspace(workspace); |
| 204 | + runner->run(problem, stream); |
| 205 | + size_t dSize |
| 206 | + = std::get<0>(problem.problem_size) * std::get<1>(problem.problem_size) * c10::elementSize(outputDtype); |
| 207 | + auto D = at::detail::empty_cuda({std::get<0>(problem.problem_size), std::get<1>(problem.problem_size)}, outputDtype, |
| 208 | + stream.device(), std::nullopt); |
| 209 | + TLLM_CUDA_CHECK(cudaMemcpyAsync( |
| 210 | + D.data_ptr(), reinterpret_cast<void const*>(handle->uc_ptr), dSize, cudaMemcpyDeviceToDevice, stream)); |
| 211 | + return D; |
| 212 | +} |
| 213 | +} // namespace |
| 214 | + |
| 215 | +namespace torch_ext |
| 216 | +{ |
| 217 | + |
| 218 | +class Fp4GemmAllreduceRunner : public torch::CustomClassHolder |
| 219 | +{ |
| 220 | +public: |
| 221 | + explicit Fp4GemmAllreduceRunner(at::ScalarType outputDtype, int64_t rank, torch::List<int64_t> group) |
| 222 | + : mOutputDtype(outputDtype) |
| 223 | + , mRank(rank) |
| 224 | + { |
| 225 | + for (int64_t rank : group) |
| 226 | + { |
| 227 | + mGroup.insert(static_cast<int>(rank)); |
| 228 | + } |
| 229 | + |
| 230 | + if (outputDtype == at::ScalarType::Half) |
| 231 | + { |
| 232 | + using Traits = GemmTypes<cutlass::float_e2m1_t, cutlass::float_e2m1_t, cutlass::half_t, cutlass::half_t, |
| 233 | + cutlass::float_ue4m3_t, cutlass::float_ue4m3_t, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, |
| 234 | + cutlass::layout::RowMajor, cutlass::layout::RowMajor>; |
| 235 | + mRunner = std::make_shared<GemmAllReduceImplRunner<Traits>>(); |
| 236 | + } |
| 237 | + else if (outputDtype == at::ScalarType::BFloat16) |
| 238 | + { |
| 239 | + using Traits = GemmTypes<cutlass::float_e2m1_t, cutlass::float_e2m1_t, cutlass::bfloat16_t, |
| 240 | + cutlass::bfloat16_t, cutlass::float_ue4m3_t, cutlass::float_ue4m3_t, cutlass::layout::RowMajor, |
| 241 | + cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, cutlass::layout::RowMajor>; |
| 242 | + mRunner = std::make_shared<GemmAllReduceImplRunner<Traits>>(); |
| 243 | + } |
| 244 | + else |
| 245 | + { |
| 246 | + C10_THROW_ERROR(NotImplementedError, "Unsupported input or output dtype"); |
| 247 | + } |
| 248 | + |
| 249 | + mConfigs = mRunner->getSupportedLaunchConfigs(); |
| 250 | + } |
| 251 | + |
| 252 | + at::Tensor runGemm(at::Tensor const& mat1, at::Tensor const& mat2, at::Tensor const& mat1Scale, |
| 253 | + at::Tensor const& mat2Scale, at::Tensor const& alpha, int64_t configIdx) const |
| 254 | + { |
| 255 | + if (configIdx < 0) |
| 256 | + configIdx = 0; |
| 257 | + |
| 258 | + TORCH_CHECK(configIdx < int64_t(mConfigs.size()), "configIdx out of bounds"); |
| 259 | + const int64_t M = mat1.size(0); |
| 260 | + const int64_t N = mat2.size(0); |
| 261 | + const int64_t K = mat1.size(1) * 2; |
| 262 | + |
| 263 | + GemmAllReduceImplInterface::ProblemArgs problemArgs; |
| 264 | + problemArgs.argProblemShape(M, N, K, 1); |
| 265 | + problemArgs.argA(mat1.data_ptr()); |
| 266 | + problemArgs.argB(mat2.data_ptr()); |
| 267 | + problemArgs.argAScale(mat1Scale.data_ptr()); |
| 268 | + problemArgs.argBScale(mat2Scale.data_ptr()); |
| 269 | + problemArgs.argC(nullptr); |
| 270 | + problemArgs.argAlphaPtr(reinterpret_cast<float const*>(alpha.const_data_ptr())); |
| 271 | + problemArgs.argBeta(0.f); |
| 272 | + problemArgs.argRanks(mRank, mGroup); |
| 273 | + problemArgs.argLaunchConfig(mConfigs[configIdx]); |
| 274 | + |
| 275 | + auto stream = at::cuda::getCurrentCUDAStream(mat1.get_device()); |
| 276 | + return runGemmImpl(mRunner.get(), problemArgs, mOutputDtype, stream); |
| 277 | + } |
| 278 | + |
| 279 | + int64_t getNumConfigs() const |
| 280 | + { |
| 281 | + return static_cast<int64_t>(mConfigs.size()); |
| 282 | + } |
| 283 | + |
| 284 | +private: |
| 285 | + at::ScalarType mOutputDtype; |
| 286 | + int mRank; |
| 287 | + std::set<int> mGroup; |
| 288 | + std::shared_ptr<GemmAllReduceImplInterface> mRunner{nullptr}; |
| 289 | + std::vector<GemmAllReduceImplInterface::LaunchConfig> mConfigs; |
| 290 | +}; |
| 291 | + |
| 292 | +} // namespace torch_ext |
| 293 | + |
| 294 | +TORCH_LIBRARY_FRAGMENT(trtllm, m) |
| 295 | +{ |
| 296 | + m.class_<torch_ext::Fp4GemmAllreduceRunner>("Fp4GemmAllreduceRunner") |
| 297 | + .def(torch::init<at::ScalarType, int64_t, torch::List<int64_t>>()) |
| 298 | + .def("run_gemm", &torch_ext::Fp4GemmAllreduceRunner::runGemm) |
| 299 | + .def("get_num_configs", &torch_ext::Fp4GemmAllreduceRunner::getNumConfigs); |
| 300 | +} |
0 commit comments