|
| 1 | +/* |
| 2 | + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
| 16 | + |
| 17 | +#include "cutlass_extensions/gemm_configs.h" |
| 18 | + |
| 19 | +#include "tensorrt_llm/common/cudaUtils.h" |
| 20 | +#include "tensorrt_llm/kernels/cutlass_kernels/include/allreduce_gemm_runner.h" |
| 21 | +#include "tensorrt_llm/runtime/ipcNvlsMemory.h" |
| 22 | +#include "tensorrt_llm/thop/thUtils.h" |
| 23 | + |
| 24 | +#include <cstddef> |
| 25 | +#include <cuda_fp16.h> |
| 26 | + |
| 27 | +#include <cstdint> |
| 28 | +#include <functional> |
| 29 | +#include <type_traits> |
| 30 | +#include <vector> |
| 31 | + |
| 32 | +using tensorrt_llm::kernels::opened_cutlass_kernels::GemmAllReduceImplRunner; |
| 33 | +using tensorrt_llm::kernels::opened_cutlass_kernels::GemmAllReduceImplInterface; |
| 34 | +using tensorrt_llm::kernels::opened_cutlass_kernels::GemmTypes; |
| 35 | +using tensorrt_llm::kernels::opened_cutlass_kernels::PersistentWorkspaceInterface; |
| 36 | + |
| 37 | +namespace |
| 38 | +{ |
| 39 | +struct AllocationKey |
| 40 | +{ |
| 41 | + c10::cuda::CUDAStream stream; |
| 42 | + std::set<int> group; |
| 43 | + |
| 44 | + bool operator==(AllocationKey const& other) const |
| 45 | + { |
| 46 | + return stream.id() == other.stream.id() && stream.device_index() == other.stream.device_index() |
| 47 | + && group == other.group; |
| 48 | + } |
| 49 | + |
| 50 | + std::string toString() const |
| 51 | + { |
| 52 | + std::stringstream ss; |
| 53 | + ss << "AllocationKey(stream: " << stream.id() << ", device: " << (int) stream.device_index() << ", group: ["; |
| 54 | + for (int rank : group) |
| 55 | + { |
| 56 | + ss << rank << ", "; |
| 57 | + } |
| 58 | + ss << "])"; |
| 59 | + return ss.str(); |
| 60 | + } |
| 61 | +}; |
| 62 | + |
| 63 | +struct AllocationKeyHash |
| 64 | +{ |
| 65 | + size_t operator()(AllocationKey const& key) const |
| 66 | + { |
| 67 | + size_t seed = 0; |
| 68 | + |
| 69 | + // Hash the stream (using stream ID and device index) |
| 70 | + hash_combine(seed, key.stream.id()); |
| 71 | + hash_combine(seed, key.stream.device_index()); |
| 72 | + |
| 73 | + // Hash the set elements |
| 74 | + for (auto const& elem : key.group) |
| 75 | + { |
| 76 | + hash_combine(seed, elem); |
| 77 | + } |
| 78 | + |
| 79 | + return seed; |
| 80 | + } |
| 81 | + |
| 82 | +private: |
| 83 | + template <typename T> |
| 84 | + static void hash_combine(size_t& seed, T const& val) |
| 85 | + { |
| 86 | + seed ^= std::hash<T>()(val) + 0x9e3779b9 + (seed << 6) + (seed >> 2); |
| 87 | + } |
| 88 | +}; |
| 89 | + |
| 90 | +class IpcNvlsHandleWrapper |
| 91 | +{ |
| 92 | +public: |
| 93 | + IpcNvlsHandleWrapper(size_t size, std::set<int> groups) |
| 94 | + : mSize(size) |
| 95 | + { |
| 96 | + mHandle = tensorrt_llm::runtime::ipcNvlsAllocate(size, groups); |
| 97 | + } |
| 98 | + |
| 99 | + tensorrt_llm::runtime::IpcNvlsHandle* getHandle() const |
| 100 | + { |
| 101 | + return mHandle; |
| 102 | + } |
| 103 | + |
| 104 | + size_t getSize() const |
| 105 | + { |
| 106 | + return mSize; |
| 107 | + } |
| 108 | + |
| 109 | + ~IpcNvlsHandleWrapper() |
| 110 | + { |
| 111 | + tensorrt_llm::runtime::ipcNvlsFree(mHandle); |
| 112 | + } |
| 113 | + |
| 114 | +private: |
| 115 | + size_t mSize; |
| 116 | + tensorrt_llm::runtime::IpcNvlsHandle* mHandle; |
| 117 | +}; |
| 118 | + |
| 119 | +class NvlsMemoryManager |
| 120 | +{ |
| 121 | +public: |
| 122 | + PersistentWorkspaceInterface* getWorkspace(GemmAllReduceImplInterface* runner, |
| 123 | + GemmAllReduceImplInterface::ProblemArgs const& problem, AllocationKey const& key) |
| 124 | + { |
| 125 | + auto newWorkspace = runner->getPersistentWorkspace(problem); |
| 126 | + auto curWorkspace = mWorkspaces[key]; |
| 127 | + |
| 128 | + if (curWorkspace == nullptr || curWorkspace->size() < newWorkspace->size()) |
| 129 | + { |
| 130 | + TLLM_LOG_WARNING("NvlsMemoryManager allocate workspace: key=%s, size=%zu.", key.toString().c_str(), |
| 131 | + newWorkspace->size()); |
| 132 | + newWorkspace->allocate(); |
| 133 | + mWorkspaces[key] = newWorkspace; |
| 134 | + } |
| 135 | + return mWorkspaces[key].get(); |
| 136 | + } |
| 137 | + |
| 138 | + tensorrt_llm::runtime::IpcNvlsHandle* getD(size_t size, AllocationKey const& key) |
| 139 | + { |
| 140 | + constexpr size_t PREFER_SIZE = 1024 * 16384 * 2; |
| 141 | + auto handle = mHandles[key]; |
| 142 | + if (handle == nullptr || handle->getSize() < size) |
| 143 | + { |
| 144 | + if (size < PREFER_SIZE) |
| 145 | + { |
| 146 | + handle = std::make_shared<IpcNvlsHandleWrapper>(PREFER_SIZE, key.group); |
| 147 | + TLLM_LOG_WARNING( |
| 148 | + "NvlsMemoryManager allocate D: key=%s, size=%zu bytes.", key.toString().c_str(), PREFER_SIZE); |
| 149 | + } |
| 150 | + else |
| 151 | + { |
| 152 | + handle = std::make_shared<IpcNvlsHandleWrapper>(size, key.group); |
| 153 | + TLLM_LOG_WARNING("NvlsMemoryManager allocate D: key=%s, size=%zu bytes.", key.toString().c_str(), size); |
| 154 | + } |
| 155 | + mHandles[key] = handle; |
| 156 | + } |
| 157 | + |
| 158 | + return mHandles[key]->getHandle(); |
| 159 | + } |
| 160 | + |
| 161 | +private: |
| 162 | + std::unordered_map<AllocationKey, std::shared_ptr<PersistentWorkspaceInterface>, AllocationKeyHash> mWorkspaces; |
| 163 | + std::unordered_map<AllocationKey, std::shared_ptr<IpcNvlsHandleWrapper>, AllocationKeyHash> mHandles; |
| 164 | +}; |
| 165 | + |
| 166 | +NvlsMemoryManager gNvlsMemoryManager; |
| 167 | +} // namespace |
| 168 | + |
| 169 | +namespace torch_ext |
| 170 | +{ |
| 171 | + |
| 172 | +class Fp4GemmAllreduceRunner : public torch::CustomClassHolder |
| 173 | +{ |
| 174 | +public: |
| 175 | + explicit Fp4GemmAllreduceRunner(at::ScalarType outputDtype, int64_t rank, torch::List<int64_t> group) |
| 176 | + : mOutputDtype(outputDtype) |
| 177 | + , mRank(rank) |
| 178 | + { |
| 179 | + for (int64_t rank : group) |
| 180 | + { |
| 181 | + mGroup.insert(static_cast<int>(rank)); |
| 182 | + } |
| 183 | + |
| 184 | + if (outputDtype == at::ScalarType::Half) |
| 185 | + { |
| 186 | + using Traits = GemmTypes<cutlass::float_e2m1_t, cutlass::float_e2m1_t, cutlass::half_t, cutlass::half_t, |
| 187 | + cutlass::float_ue4m3_t, cutlass::float_ue4m3_t, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, |
| 188 | + cutlass::layout::RowMajor, cutlass::layout::RowMajor>; |
| 189 | + mRunner = std::make_shared<GemmAllReduceImplRunner<Traits>>(); |
| 190 | + } |
| 191 | + else if (outputDtype == at::ScalarType::BFloat16) |
| 192 | + { |
| 193 | + using Traits = GemmTypes<cutlass::float_e2m1_t, cutlass::float_e2m1_t, cutlass::bfloat16_t, |
| 194 | + cutlass::bfloat16_t, cutlass::float_ue4m3_t, cutlass::float_ue4m3_t, cutlass::layout::RowMajor, |
| 195 | + cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, cutlass::layout::RowMajor>; |
| 196 | + mRunner = std::make_shared<GemmAllReduceImplRunner<Traits>>(); |
| 197 | + } |
| 198 | + else |
| 199 | + { |
| 200 | + C10_THROW_ERROR(NotImplementedError, "Unsupported input or output dtype"); |
| 201 | + } |
| 202 | + |
| 203 | + mConfigs = mRunner->getSupportedLaunchConfigs(); |
| 204 | + } |
| 205 | + |
| 206 | + at::Tensor runGemm(at::Tensor const& mat1, at::Tensor const& mat2, at::Tensor const& mat1Scale, |
| 207 | + at::Tensor const& mat2Scale, at::Tensor const& alpha, int64_t configIdx) const |
| 208 | + { |
| 209 | + if (configIdx < 0) |
| 210 | + configIdx = 0; |
| 211 | + |
| 212 | + TORCH_CHECK(configIdx < int64_t(mConfigs.size()), "configIdx out of bounds"); |
| 213 | + const int64_t M = mat1.size(0); |
| 214 | + const int64_t N = mat2.size(0); |
| 215 | + const int64_t K = mat1.size(1) * 2; |
| 216 | + |
| 217 | + GemmAllReduceImplInterface::ProblemArgs problemArgs; |
| 218 | + problemArgs.argProblemShape(M, N, K, 1); |
| 219 | + problemArgs.argA(mat1.data_ptr()); |
| 220 | + problemArgs.argB(mat2.data_ptr()); |
| 221 | + problemArgs.argAScale(mat1Scale.data_ptr()); |
| 222 | + problemArgs.argBScale(mat2Scale.data_ptr()); |
| 223 | + problemArgs.argC(nullptr); |
| 224 | + problemArgs.argAlphaPtr(reinterpret_cast<float const*>(alpha.const_data_ptr())); |
| 225 | + problemArgs.argBeta(0.f); |
| 226 | + problemArgs.argRanks(mRank, mGroup); |
| 227 | + problemArgs.argLaunchConfig(mConfigs[configIdx]); |
| 228 | + |
| 229 | + size_t dSize = M * N * c10::elementSize(mOutputDtype); |
| 230 | + auto stream = at::cuda::getCurrentCUDAStream(mat1.get_device()); |
| 231 | + |
| 232 | + auto handle = gNvlsMemoryManager.getD(dSize, AllocationKey{stream, mGroup}); |
| 233 | + problemArgs.argD((void*) handle->uc_ptr, (void*) handle->mc_ptr, (void**) handle->ipc_uc_ptrs.data()); |
| 234 | + |
| 235 | + auto workspace = gNvlsMemoryManager.getWorkspace(mRunner.get(), problemArgs, AllocationKey{stream, mGroup}); |
| 236 | + problemArgs.argWorkspace(workspace); |
| 237 | + mRunner->run(problemArgs, stream); |
| 238 | + auto options = mat1.options().dtype(mOutputDtype); |
| 239 | + auto D = at::from_blob((void*) handle->uc_ptr, {M, N}, {N, 1}, options); |
| 240 | + |
| 241 | + std::cout << "benzh@Fp4GemmAllreduceRunner local rank: " << mRank; |
| 242 | + return D; |
| 243 | + } |
| 244 | + |
| 245 | + int64_t getNumConfigs() const |
| 246 | + { |
| 247 | + return static_cast<int64_t>(mConfigs.size()); |
| 248 | + } |
| 249 | + |
| 250 | +private: |
| 251 | + at::ScalarType mOutputDtype; |
| 252 | + int mRank; |
| 253 | + std::set<int> mGroup; |
| 254 | + std::shared_ptr<GemmAllReduceImplInterface> mRunner{nullptr}; |
| 255 | + std::vector<GemmAllReduceImplInterface::LaunchConfig> mConfigs; |
| 256 | +}; |
| 257 | + |
| 258 | +} // namespace torch_ext |
| 259 | + |
| 260 | +TORCH_LIBRARY_FRAGMENT(trtllm, m) |
| 261 | +{ |
| 262 | + m.class_<torch_ext::Fp4GemmAllreduceRunner>("Fp4GemmAllreduceRunner") |
| 263 | + .def(torch::init<at::ScalarType, int64_t, torch::List<int64_t>>()) |
| 264 | + .def("run_gemm", &torch_ext::Fp4GemmAllreduceRunner::runGemm) |
| 265 | + .def("get_num_configs", &torch_ext::Fp4GemmAllreduceRunner::getNumConfigs); |
| 266 | +} |
0 commit comments