Commit da85f02

add fp4 gemm + allreduce

Signed-off-by: benzh <[email protected]>
1 parent 4317859 commit da85f02

File tree: 7 files changed, +651 −18 lines

cpp/tensorrt_llm/kernels/cutlass_kernels/allreduce_gemm/allreduce_gemm_impl_sm100.h

Lines changed: 1 addition & 1 deletion
@@ -141,7 +141,7 @@ class GemmAllReduceImplTwoshot_Sm100 : public GemmAllReduceImplInterface
     // Epilogue
     ////////////////
     using FusionCallbacks = cutlass::epilogue::fusion::LinearCombination<ElementD, float, void, float>;
-    using TileBarrierType = cutlass::MulticastSystemBarrier<cutlass::detail::SyncNoOp, true>;
+    using TileBarrierType = cutlass::MulticastSystemBarrier<cutlass::detail::SyncNoOp, false>;
     using EpilogueScheduleType = typename MmaAdapter<MmaType, IsFP4>::EpilogueSchedule;
     using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
     using FusionOp

cpp/tensorrt_llm/thop/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
@@ -104,7 +104,8 @@ add_library(
   loraOp.cpp
   finegrained_mixed_dtype_gemm_thop.cpp
   tinygemm2.cpp
-  dsv3RopeOp.cpp)
+  dsv3RopeOp.cpp
+  fusedGemmAllreduceOp.cpp)
 set_property(TARGET th_common PROPERTY POSITION_INDEPENDENT_CODE ON)
 target_link_libraries(
   th_common PRIVATE ${TORCH_LIBRARIES} th_utils ${Python3_LIBRARIES}
cpp/tensorrt_llm/thop/fusedGemmAllreduceOp.cpp (new file, added to th_common in the CMakeLists change above)

Lines changed: 300 additions & 0 deletions
@@ -0,0 +1,300 @@
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "cutlass_extensions/gemm_configs.h"

#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/cutlass_kernels/include/allreduce_gemm_runner.h"
#include "tensorrt_llm/runtime/ipcNvlsMemory.h"
#include "tensorrt_llm/thop/thUtils.h"

#include <ATen/cuda/EmptyTensor.h>

#include <cstddef>
#include <cuda_fp16.h>

#include <cstdint>
#include <functional>
#include <type_traits>
#include <vector>

using tensorrt_llm::kernels::opened_cutlass_kernels::GemmAllReduceImplRunner;
using tensorrt_llm::kernels::opened_cutlass_kernels::GemmAllReduceImplInterface;
using tensorrt_llm::kernels::opened_cutlass_kernels::GemmTypes;
using tensorrt_llm::kernels::opened_cutlass_kernels::PersistentWorkspaceInterface;

namespace
{
struct AllocationKey
{
    int64_t device_index;
    std::set<int> group;

    bool operator==(AllocationKey const& other) const
    {
        return device_index == other.device_index && group == other.group;
    }

    std::string toString() const
    {
        std::stringstream ss;
        ss << "AllocationKey(device: " << device_index << ", group: [";
        for (int rank : group)
        {
            ss << rank << ", ";
        }
        ss << "])";
        return ss.str();
    }
};

struct AllocationKeyHash
{
    size_t operator()(AllocationKey const& key) const
    {
        size_t seed = 0;

        // Hash the device index
        hash_combine(seed, key.device_index);

        // Hash the set elements
        for (auto const& elem : key.group)
        {
            hash_combine(seed, elem);
        }

        return seed;
    }

private:
    template <typename T>
    static void hash_combine(size_t& seed, T const& val)
    {
        seed ^= std::hash<T>()(val) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    }
};

class IpcNvlsHandleWrapper
{
public:
    IpcNvlsHandleWrapper(size_t size, std::set<int> groups)
        : mSize(size)
    {
        mHandle = tensorrt_llm::runtime::ipcNvlsAllocate(size, groups);
    }

    tensorrt_llm::runtime::IpcNvlsHandle* getHandle() const
    {
        return mHandle;
    }

    size_t getSize() const
    {
        return mSize;
    }

    ~IpcNvlsHandleWrapper()
    {
        tensorrt_llm::runtime::ipcNvlsFree(mHandle);
    }

private:
    size_t mSize;
    tensorrt_llm::runtime::IpcNvlsHandle* mHandle;
};

std::once_flag init_flag;

size_t getPreferredWorkspaceSize()
{
    // 128MB
    static size_t preferredWorkspaceSize = 134217728;
    std::call_once(init_flag,
        [&]()
        {
            char const* envWorkspaceSize = std::getenv("TRTLLM_GEMM_ALLREDUCE_WORKSPACE_SIZE");
            size_t workspaceSize = 0;
            if (envWorkspaceSize != nullptr)
            {
                workspaceSize = std::atoi(envWorkspaceSize);
            }
            preferredWorkspaceSize = std::max(preferredWorkspaceSize, workspaceSize);
        });
    return preferredWorkspaceSize;
}

class GemmAllreduceNvlsMemoryManager
{
public:
    GemmAllreduceNvlsMemoryManager()
    {
        TLLM_LOG_INFO("GemmAllreduceNvlsMemoryManager constructor");
    }

    ~GemmAllreduceNvlsMemoryManager()
    {
        TLLM_LOG_INFO("GemmAllreduceNvlsMemoryManager destructor");
    }

    std::pair<PersistentWorkspaceInterface*, tensorrt_llm::runtime::IpcNvlsHandle*> getWorkspace(
        GemmAllReduceImplInterface* runner, GemmAllReduceImplInterface::ProblemArgs const& problem,
        AllocationKey const& key)
    {
        int M = std::get<0>(problem.problem_size);
        int N = std::get<1>(problem.problem_size);
        size_t requiredSize = M * N * 2;
        size_t preferredWorkspaceSize = getPreferredWorkspaceSize();
        if (requiredSize > preferredWorkspaceSize)
        {
            std::stringstream ss;
            ss << "Please set TRTLLM_GEMM_ALLREDUCE_WORKSPACE_SIZE to at least " << requiredSize << " bytes";
            C10_THROW_ERROR(ErrorAlwaysShowCppStacktrace, ss.str().c_str());
        }

        auto handle = mHandles[key];
        if (handle == nullptr)
        {
            TLLM_LOG_INFO("Creating allreduce workspace for %s", key.toString().c_str());
            handle = std::make_shared<IpcNvlsHandleWrapper>(preferredWorkspaceSize, key.group);
            GemmAllReduceImplInterface::ProblemArgs tmpArgs;
            int maxN = 16384;
            int maxM = preferredWorkspaceSize / (maxN * 2);
            tmpArgs.argProblemShape(maxM, maxN, 512, 1)
                .argRanks(problem.rank, problem.ranks)
                .argLaunchConfig(runner->getSupportedLaunchConfigs()[0]);
            auto workspace = runner->getPersistentWorkspace(tmpArgs);
            workspace->allocate();
            mWorkspaces[key] = workspace;
            mHandles[key] = handle;
        }
        return std::make_pair(mWorkspaces[key].get(), mHandles[key]->getHandle());
    }

private:
    std::unordered_map<AllocationKey, std::shared_ptr<PersistentWorkspaceInterface>, AllocationKeyHash> mWorkspaces;
    std::unordered_map<AllocationKey, std::shared_ptr<IpcNvlsHandleWrapper>, AllocationKeyHash> mHandles;
};

GemmAllreduceNvlsMemoryManager* getGemmAllreduceNvlsMemoryManager()
{
    static GemmAllreduceNvlsMemoryManager gNvlsMemoryManager;
    return &gNvlsMemoryManager;
}

at::Tensor runGemmImpl(GemmAllReduceImplInterface* runner, GemmAllReduceImplInterface::ProblemArgs& problem,
    at::ScalarType outputDtype, c10::cuda::CUDAStream stream)
{
    AllocationKey key{stream.device_index(), problem.ranks};
    auto [workspace, handle] = getGemmAllreduceNvlsMemoryManager()->getWorkspace(runner, problem, key);
    problem.argD((void*) handle->uc_ptr, (void*) handle->mc_ptr, (void**) handle->ipc_uc_ptrs.data());
    problem.argWorkspace(workspace);
    runner->run(problem, stream);
    size_t dSize
        = std::get<0>(problem.problem_size) * std::get<1>(problem.problem_size) * c10::elementSize(outputDtype);
    auto D = at::detail::empty_cuda({std::get<0>(problem.problem_size), std::get<1>(problem.problem_size)},
        outputDtype, stream.device(), std::nullopt);
    TLLM_CUDA_CHECK(cudaMemcpyAsync(
        D.data_ptr(), reinterpret_cast<void const*>(handle->uc_ptr), dSize, cudaMemcpyDeviceToDevice, stream));
    return D;
}
} // namespace

namespace torch_ext
{

class Fp4GemmAllreduceRunner : public torch::CustomClassHolder
{
public:
    explicit Fp4GemmAllreduceRunner(at::ScalarType outputDtype, int64_t rank, torch::List<int64_t> group)
        : mOutputDtype(outputDtype)
        , mRank(rank)
    {
        for (int64_t rank : group)
        {
            mGroup.insert(static_cast<int>(rank));
        }

        if (outputDtype == at::ScalarType::Half)
        {
            using Traits = GemmTypes<cutlass::float_e2m1_t, cutlass::float_e2m1_t, cutlass::half_t, cutlass::half_t,
                cutlass::float_ue4m3_t, cutlass::float_ue4m3_t, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor,
                cutlass::layout::RowMajor, cutlass::layout::RowMajor>;
            mRunner = std::make_shared<GemmAllReduceImplRunner<Traits>>();
        }
        else if (outputDtype == at::ScalarType::BFloat16)
        {
            using Traits = GemmTypes<cutlass::float_e2m1_t, cutlass::float_e2m1_t, cutlass::bfloat16_t,
                cutlass::bfloat16_t, cutlass::float_ue4m3_t, cutlass::float_ue4m3_t, cutlass::layout::RowMajor,
                cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, cutlass::layout::RowMajor>;
            mRunner = std::make_shared<GemmAllReduceImplRunner<Traits>>();
        }
        else
        {
            C10_THROW_ERROR(NotImplementedError, "Unsupported input or output dtype");
        }

        mConfigs = mRunner->getSupportedLaunchConfigs();
    }

    at::Tensor runGemm(at::Tensor const& mat1, at::Tensor const& mat2, at::Tensor const& mat1Scale,
        at::Tensor const& mat2Scale, at::Tensor const& alpha, int64_t configIdx) const
    {
        if (configIdx < 0)
            configIdx = 0;

        TORCH_CHECK(configIdx < int64_t(mConfigs.size()), "configIdx out of bounds");
        const int64_t M = mat1.size(0);
        const int64_t N = mat2.size(0);
        const int64_t K = mat1.size(1) * 2;

        GemmAllReduceImplInterface::ProblemArgs problemArgs;
        problemArgs.argProblemShape(M, N, K, 1);
        problemArgs.argA(mat1.data_ptr());
        problemArgs.argB(mat2.data_ptr());
        problemArgs.argAScale(mat1Scale.data_ptr());
        problemArgs.argBScale(mat2Scale.data_ptr());
        problemArgs.argC(nullptr);
        problemArgs.argAlphaPtr(reinterpret_cast<float const*>(alpha.const_data_ptr()));
        problemArgs.argBeta(0.f);
        problemArgs.argRanks(mRank, mGroup);
        problemArgs.argLaunchConfig(mConfigs[configIdx]);

        auto stream = at::cuda::getCurrentCUDAStream(mat1.get_device());
        return runGemmImpl(mRunner.get(), problemArgs, mOutputDtype, stream);
    }

    int64_t getNumConfigs() const
    {
        return static_cast<int64_t>(mConfigs.size());
    }

private:
    at::ScalarType mOutputDtype;
    int mRank;
    std::set<int> mGroup;
    std::shared_ptr<GemmAllReduceImplInterface> mRunner{nullptr};
    std::vector<GemmAllReduceImplInterface::LaunchConfig> mConfigs;
};

} // namespace torch_ext

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.class_<torch_ext::Fp4GemmAllreduceRunner>("Fp4GemmAllreduceRunner")
        .def(torch::init<at::ScalarType, int64_t, torch::List<int64_t>>())
        .def("run_gemm", &torch_ext::Fp4GemmAllreduceRunner::runGemm)
        .def("get_num_configs", &torch_ext::Fp4GemmAllreduceRunner::getNumConfigs);
}
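
For context, a hedged usage sketch of the runner registered above (not part of the commit). It assumes Fp4GemmAllreduceRunner were reachable from a header; in the commit it lives only in this translation unit and is intended to be driven from Python via the torch.classes.trtllm registration. Shapes follow runGemm(): A is [M, K/2] and B is [N, K/2] of packed FP4 (two e2m1 values per byte, hence K = mat1.size(1) * 2); the scale-factor shapes and the 2-rank group below are placeholder assumptions, and running it for real requires one process per rank on NVLS-capable GPUs.

// Illustrative sketch only: assumes the runner's declaration is visible and that this
// process is rank `rank` of a two-GPU group {0, 1}. Scale tensor shapes are placeholders;
// the real swizzled ue4m3 block-scale layout is defined by the CUTLASS kernel.
#include <ATen/ATen.h>
#include <torch/custom_class.h>

at::Tensor fp4GemmAllreduceExample(int64_t rank)
{
    torch::List<int64_t> group;
    group.push_back(0);
    group.push_back(1);

    // bf16 output; FP4 (e2m1) inputs with ue4m3 block scales, matching the Traits above.
    torch_ext::Fp4GemmAllreduceRunner runner(at::ScalarType::BFloat16, rank, group);

    int64_t const M = 128, N = 256, K = 512;
    auto byteOpts = at::TensorOptions().dtype(at::kByte).device(at::kCUDA);
    auto A = at::zeros({M, K / 2}, byteOpts);       // packed FP4: two e2m1 values per byte
    auto B = at::zeros({N, K / 2}, byteOpts);       // B is K-major (ColumnMajor layout)
    auto aScale = at::zeros({M, K / 16}, byteOpts); // ue4m3 block scales, layout assumed
    auto bScale = at::zeros({N, K / 16}, byteOpts);
    auto alpha = at::ones({1}, at::TensorOptions().dtype(at::kFloat).device(at::kCUDA));

    // configIdx 0 selects the first supported launch config; getNumConfigs() bounds a sweep.
    return runner.runGemm(A, B, aScale, bScale, alpha, /*configIdx=*/0);
}

From Python, the same object would typically be constructed after loading the library as torch.classes.trtllm.Fp4GemmAllreduceRunner(torch.bfloat16, rank, [0, 1]) and invoked through run_gemm, mirroring the TORCH_LIBRARY_FRAGMENT bindings above.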
