|
| 1 | +/* |
| 2 | + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
| 16 | + |
| 17 | +#include "cutlass_extensions/gemm_configs.h" |
| 18 | + |
| 19 | +#include "tensorrt_llm/common/cudaUtils.h" |
| 20 | +#include "tensorrt_llm/kernels/cutlass_kernels/include/allreduce_gemm_runner.h" |
| 21 | +#include "tensorrt_llm/runtime/ipcNvlsMemory.h" |
| 22 | +#include "tensorrt_llm/thop/thUtils.h" |
| 23 | + |
| 24 | +#include <cstddef> |
| 25 | +#include <cuda_fp16.h> |
| 26 | + |
| 27 | +#include <cstdint> |
| 28 | +#include <functional> |
| 29 | +#include <type_traits> |
| 30 | +#include <vector> |
| 31 | + |
| 32 | +using tensorrt_llm::kernels::opened_cutlass_kernels::GemmAllReduceImplRunner; |
| 33 | +using tensorrt_llm::kernels::opened_cutlass_kernels::GemmAllReduceImplInterface; |
| 34 | +using tensorrt_llm::kernels::opened_cutlass_kernels::GemmTypes; |
| 35 | +using tensorrt_llm::kernels::opened_cutlass_kernels::PersistentWorkspaceInterface; |
| 36 | + |
| 37 | +namespace torch_ext |
| 38 | +{ |
| 39 | +PersistentWorkspaceInterface* getWorkspace( |
| 40 | + GemmAllReduceImplInterface* runner, GemmAllReduceImplInterface::ProblemArgs const& problem) |
| 41 | +{ |
| 42 | + thread_local std::shared_ptr<PersistentWorkspaceInterface> curWorkspace; |
| 43 | + thread_local size_t curWorkspaceSize = 0; |
| 44 | + auto newWorkspace = runner->getPersistentWorkspace(problem); |
| 45 | + if (newWorkspace->size() > curWorkspaceSize) |
| 46 | + { |
| 47 | + TLLM_LOG_WARNING( |
| 48 | + "Fp4GemmAllreduceRunner workspace is not enough, allocate new workspace", newWorkspace->size()); |
| 49 | + newWorkspace->allocate(); |
| 50 | + curWorkspaceSize = newWorkspace->size(); |
| 51 | + curWorkspace = newWorkspace; |
| 52 | + } |
| 53 | + return curWorkspace.get(); |
| 54 | +} |
| 55 | + |
| 56 | +class Fp4GemmAllreduceRunner : public torch::CustomClassHolder |
| 57 | +{ |
| 58 | +public: |
| 59 | + explicit Fp4GemmAllreduceRunner(at::ScalarType outputDtype, int64_t rank, torch::List<int64_t> group) |
| 60 | + : mOutputDtype(outputDtype) |
| 61 | + , mRank(rank) |
| 62 | + { |
| 63 | + for (int64_t rank : group) |
| 64 | + { |
| 65 | + mGroup.insert(static_cast<int>(rank)); |
| 66 | + } |
| 67 | + |
| 68 | + if (outputDtype == at::ScalarType::Half) |
| 69 | + { |
| 70 | + using Traits = GemmTypes<cutlass::float_e2m1_t, cutlass::float_e2m1_t, cutlass::half_t, cutlass::half_t, |
| 71 | + cutlass::float_ue4m3_t, cutlass::float_ue4m3_t, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, |
| 72 | + cutlass::layout::RowMajor, cutlass::layout::RowMajor>; |
| 73 | + mRunner = std::make_shared<GemmAllReduceImplRunner<Traits>>(); |
| 74 | + } |
| 75 | + else if (outputDtype == at::ScalarType::BFloat16) |
| 76 | + { |
| 77 | + using Traits = GemmTypes<cutlass::float_e2m1_t, cutlass::float_e2m1_t, cutlass::bfloat16_t, |
| 78 | + cutlass::bfloat16_t, cutlass::float_ue4m3_t, cutlass::float_ue4m3_t, cutlass::layout::RowMajor, |
| 79 | + cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, cutlass::layout::RowMajor>; |
| 80 | + mRunner = std::make_shared<GemmAllReduceImplRunner<Traits>>(); |
| 81 | + } |
| 82 | + else |
| 83 | + { |
| 84 | + C10_THROW_ERROR(NotImplementedError, "Unsupported input or output dtype"); |
| 85 | + } |
| 86 | + |
| 87 | + mConfigs = mRunner->getSupportedLaunchConfigs(); |
| 88 | + } |
| 89 | + |
| 90 | + at::Tensor runGemm(at::Tensor const& mat1, at::Tensor const& mat2, at::Tensor const& mat1Scale, |
| 91 | + at::Tensor const& mat2Scale, at::Tensor const& alpha, int64_t configIdx) const |
| 92 | + { |
| 93 | + if (configIdx < 0) |
| 94 | + configIdx = 0; |
| 95 | + |
| 96 | + TORCH_CHECK(configIdx < int64_t(mConfigs.size()), "configIdx out of bounds"); |
| 97 | + const int64_t M = mat1.size(0); |
| 98 | + const int64_t N = mat2.size(0); |
| 99 | + const int64_t K = mat1.size(1) * 2; |
| 100 | + |
| 101 | + GemmAllReduceImplInterface::ProblemArgs problemArgs; |
| 102 | + problemArgs.argProblemShape(M, N, K, 1); |
| 103 | + problemArgs.argA(mat1.data_ptr()); |
| 104 | + problemArgs.argB(mat2.data_ptr()); |
| 105 | + problemArgs.argAScale(mat1Scale.data_ptr()); |
| 106 | + problemArgs.argBScale(mat2Scale.data_ptr()); |
| 107 | + problemArgs.argC(nullptr); |
| 108 | + problemArgs.argAlphaPtr(reinterpret_cast<float const*>(alpha.const_data_ptr())); |
| 109 | + problemArgs.argBeta(0.f); |
| 110 | + problemArgs.argRanks(mRank, mGroup); |
| 111 | + problemArgs.argLaunchConfig(mConfigs[configIdx]); |
| 112 | + |
| 113 | + size_t dSize = M * N * c10::elementSize(mOutputDtype); |
| 114 | + auto handle = tensorrt_llm::runtime::ipcNvlsAllocate(dSize, mGroup); |
| 115 | + problemArgs.argD((void*) handle->uc_ptr, (void*) handle->mc_ptr, (void**) handle->ipc_uc_ptrs.data()); |
| 116 | + |
| 117 | + auto workspace = getWorkspace(mRunner.get(), problemArgs); |
| 118 | + problemArgs.argWorkspace(workspace); |
| 119 | + |
| 120 | + auto stream = at::cuda::getCurrentCUDAStream(mat1.get_device()); |
| 121 | + mRunner->run(problemArgs, stream); |
| 122 | + |
| 123 | + auto options = mat1.options().dtype(mOutputDtype); |
| 124 | + auto deleter = [=](void* unused) { ipcNvlsFree(handle); }; |
| 125 | + auto D = at::from_blob((void*) handle->uc_ptr, {M, N}, {N, 1}, deleter, options); |
| 126 | + return D; |
| 127 | + } |
| 128 | + |
| 129 | + int64_t getNumConfigs() const |
| 130 | + { |
| 131 | + return static_cast<int64_t>(mConfigs.size()); |
| 132 | + } |
| 133 | + |
| 134 | +private: |
| 135 | + at::ScalarType mOutputDtype; |
| 136 | + int mRank; |
| 137 | + std::set<int> mGroup; |
| 138 | + std::shared_ptr<GemmAllReduceImplInterface> mRunner{nullptr}; |
| 139 | + std::vector<GemmAllReduceImplInterface::LaunchConfig> mConfigs; |
| 140 | +}; |
| 141 | + |
| 142 | +} // namespace torch_ext |
| 143 | + |
| 144 | +TORCH_LIBRARY_FRAGMENT(trtllm, m) |
| 145 | +{ |
| 146 | + m.class_<torch_ext::Fp4GemmAllreduceRunner>("Fp4GemmAllreduceRunner") |
| 147 | + .def(torch::init<at::ScalarType, int64_t, torch::List<int64_t>>()) |
| 148 | + .def("run_gemm", &torch_ext::Fp4GemmAllreduceRunner::runGemm) |
| 149 | + .def("get_num_configs", &torch_ext::Fp4GemmAllreduceRunner::getNumConfigs); |
| 150 | +} |
0 commit comments