
Commit 3c42ff0

add fp4 gemm + allreduce
Signed-off-by: benzh <[email protected]>
1 parent af2849c commit 3c42ff0

File tree

5 files changed: +615 -11 lines changed


cpp/tensorrt_llm/thop/CMakeLists.txt

Lines changed: 2 additions & 1 deletion

@@ -104,7 +104,8 @@ add_library(
   loraOp.cpp
   finegrained_mixed_dtype_gemm_thop.cpp
   tinygemm2.cpp
-  dsv3RopeOp.cpp)
+  dsv3RopeOp.cpp
+  fusedGemmAllreduceOp.cpp)
 set_property(TARGET th_common PROPERTY POSITION_INDEPENDENT_CODE ON)
 target_link_libraries(
   th_common PRIVATE ${TORCH_LIBRARIES} th_utils ${Python3_LIBRARIES}
cpp/tensorrt_llm/thop/fusedGemmAllreduceOp.cpp

Lines changed: 265 additions & 0 deletions

/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "cutlass_extensions/gemm_configs.h"

#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/cutlass_kernels/include/allreduce_gemm_runner.h"
#include "tensorrt_llm/runtime/ipcNvlsMemory.h"
#include "tensorrt_llm/thop/thUtils.h"

#include <cstddef>
#include <cuda_fp16.h>

#include <cstdint>
#include <functional>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>

using tensorrt_llm::kernels::opened_cutlass_kernels::GemmAllReduceImplRunner;
using tensorrt_llm::kernels::opened_cutlass_kernels::GemmAllReduceImplInterface;
using tensorrt_llm::kernels::opened_cutlass_kernels::GemmTypes;
using tensorrt_llm::kernels::opened_cutlass_kernels::PersistentWorkspaceInterface;

namespace
{
// Identifies an allocation by the CUDA stream it is used on and the set of ranks in the group.
struct AllocationKey
{
    c10::cuda::CUDAStream stream;
    std::set<int> group;

    bool operator==(AllocationKey const& other) const
    {
        return stream.id() == other.stream.id() && stream.device_index() == other.stream.device_index()
            && group == other.group;
    }

    std::string toString() const
    {
        std::stringstream ss;
        ss << "AllocationKey(stream: " << stream.id() << ", device: " << (int) stream.device_index() << ", group: [";
        for (int rank : group)
        {
            ss << rank << ", ";
        }
        ss << "])";
        return ss.str();
    }
};

struct AllocationKeyHash
{
    size_t operator()(AllocationKey const& key) const
    {
        size_t seed = 0;

        // Hash the stream (using stream ID and device index)
        hash_combine(seed, key.stream.id());
        hash_combine(seed, key.stream.device_index());

        // Hash the set elements
        for (auto const& elem : key.group)
        {
            hash_combine(seed, elem);
        }

        return seed;
    }

private:
    template <typename T>
    static void hash_combine(size_t& seed, T const& val)
    {
        seed ^= std::hash<T>()(val) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    }
};

class IpcNvlsHandleWrapper
{
public:
    IpcNvlsHandleWrapper(size_t size, std::set<int> groups)
        : mSize(size)
    {
        mHandle = tensorrt_llm::runtime::ipcNvlsAllocate(size, groups);
    }

    tensorrt_llm::runtime::IpcNvlsHandle* getHandle() const
    {
        return mHandle;
    }

    size_t getSize() const
    {
        return mSize;
    }

    ~IpcNvlsHandleWrapper()
    {
        tensorrt_llm::runtime::ipcNvlsFree(mHandle);
    }

private:
    size_t mSize;
    tensorrt_llm::runtime::IpcNvlsHandle* mHandle;
};

// Caches NVLS output buffers and persistent workspaces per (stream, group) so repeated launches reuse allocations.
class NvlsMemoryManager
{
public:
    PersistentWorkspaceInterface* getWorkspace(GemmAllReduceImplInterface* runner,
        GemmAllReduceImplInterface::ProblemArgs const& problem, AllocationKey const& key)
    {
        auto newWorkspace = runner->getPersistentWorkspace(problem);
        auto curWorkspace = mWorkspaces[key];

        if (curWorkspace == nullptr || curWorkspace->size() < newWorkspace->size())
        {
            TLLM_LOG_WARNING("NvlsMemoryManager allocate workspace: key=%s, size=%zu.", key.toString().c_str(),
                newWorkspace->size());
            newWorkspace->allocate();
            mWorkspaces[key] = newWorkspace;
        }
        return mWorkspaces[key].get();
    }

    tensorrt_llm::runtime::IpcNvlsHandle* getD(size_t size, AllocationKey const& key)
    {
        constexpr size_t PREFER_SIZE = 1024 * 16384 * 2;
        auto handle = mHandles[key];
        if (handle == nullptr || handle->getSize() < size)
        {
            if (size < PREFER_SIZE)
            {
                handle = std::make_shared<IpcNvlsHandleWrapper>(PREFER_SIZE, key.group);
                TLLM_LOG_WARNING(
                    "NvlsMemoryManager allocate D: key=%s, size=%zu bytes.", key.toString().c_str(), PREFER_SIZE);
            }
            else
            {
                handle = std::make_shared<IpcNvlsHandleWrapper>(size, key.group);
                TLLM_LOG_WARNING("NvlsMemoryManager allocate D: key=%s, size=%zu bytes.", key.toString().c_str(), size);
            }
            mHandles[key] = handle;
        }

        return mHandles[key]->getHandle();
    }

private:
    std::unordered_map<AllocationKey, std::shared_ptr<PersistentWorkspaceInterface>, AllocationKeyHash> mWorkspaces;
    std::unordered_map<AllocationKey, std::shared_ptr<IpcNvlsHandleWrapper>, AllocationKeyHash> mHandles;
};

NvlsMemoryManager gNvlsMemoryManager;
} // namespace

namespace torch_ext
{

// Torch custom class that runs a fused NVFP4 GEMM + all-reduce via the CUTLASS runner.
class Fp4GemmAllreduceRunner : public torch::CustomClassHolder
{
public:
    explicit Fp4GemmAllreduceRunner(at::ScalarType outputDtype, int64_t rank, torch::List<int64_t> group)
        : mOutputDtype(outputDtype)
        , mRank(rank)
    {
        for (int64_t rank : group)
        {
            mGroup.insert(static_cast<int>(rank));
        }

        if (outputDtype == at::ScalarType::Half)
        {
            using Traits = GemmTypes<cutlass::float_e2m1_t, cutlass::float_e2m1_t, cutlass::half_t, cutlass::half_t,
                cutlass::float_ue4m3_t, cutlass::float_ue4m3_t, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor,
                cutlass::layout::RowMajor, cutlass::layout::RowMajor>;
            mRunner = std::make_shared<GemmAllReduceImplRunner<Traits>>();
        }
        else if (outputDtype == at::ScalarType::BFloat16)
        {
            using Traits = GemmTypes<cutlass::float_e2m1_t, cutlass::float_e2m1_t, cutlass::bfloat16_t,
                cutlass::bfloat16_t, cutlass::float_ue4m3_t, cutlass::float_ue4m3_t, cutlass::layout::RowMajor,
                cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, cutlass::layout::RowMajor>;
            mRunner = std::make_shared<GemmAllReduceImplRunner<Traits>>();
        }
        else
        {
            C10_THROW_ERROR(NotImplementedError, "Unsupported input or output dtype");
        }

        mConfigs = mRunner->getSupportedLaunchConfigs();
    }

    at::Tensor runGemm(at::Tensor const& mat1, at::Tensor const& mat2, at::Tensor const& mat1Scale,
        at::Tensor const& mat2Scale, at::Tensor const& alpha, int64_t configIdx) const
    {
        if (configIdx < 0)
            configIdx = 0;

        TORCH_CHECK(configIdx < int64_t(mConfigs.size()), "configIdx out of bounds");
        const int64_t M = mat1.size(0);
        const int64_t N = mat2.size(0);
        // FP4 operands are packed two values per byte, so K is twice the packed extent.
        const int64_t K = mat1.size(1) * 2;

        GemmAllReduceImplInterface::ProblemArgs problemArgs;
        problemArgs.argProblemShape(M, N, K, 1);
        problemArgs.argA(mat1.data_ptr());
        problemArgs.argB(mat2.data_ptr());
        problemArgs.argAScale(mat1Scale.data_ptr());
        problemArgs.argBScale(mat2Scale.data_ptr());
        problemArgs.argC(nullptr);
        problemArgs.argAlphaPtr(reinterpret_cast<float const*>(alpha.const_data_ptr()));
        problemArgs.argBeta(0.f);
        problemArgs.argRanks(mRank, mGroup);
        problemArgs.argLaunchConfig(mConfigs[configIdx]);

        size_t dSize = M * N * c10::elementSize(mOutputDtype);
        auto stream = at::cuda::getCurrentCUDAStream(mat1.get_device());

        auto handle = gNvlsMemoryManager.getD(dSize, AllocationKey{stream, mGroup});
        problemArgs.argD((void*) handle->uc_ptr, (void*) handle->mc_ptr, (void**) handle->ipc_uc_ptrs.data());

        auto workspace = gNvlsMemoryManager.getWorkspace(mRunner.get(), problemArgs, AllocationKey{stream, mGroup});
        problemArgs.argWorkspace(workspace);
        mRunner->run(problemArgs, stream);
        auto options = mat1.options().dtype(mOutputDtype);
        auto D = at::from_blob((void*) handle->uc_ptr, {M, N}, {N, 1}, options);

        return D;
    }

    int64_t getNumConfigs() const
    {
        return static_cast<int64_t>(mConfigs.size());
    }

private:
    at::ScalarType mOutputDtype;
    int mRank;
    std::set<int> mGroup;
    std::shared_ptr<GemmAllReduceImplInterface> mRunner{nullptr};
    std::vector<GemmAllReduceImplInterface::LaunchConfig> mConfigs;
};

} // namespace torch_ext

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.class_<torch_ext::Fp4GemmAllreduceRunner>("Fp4GemmAllreduceRunner")
        .def(torch::init<at::ScalarType, int64_t, torch::List<int64_t>>())
        .def("run_gemm", &torch_ext::Fp4GemmAllreduceRunner::runGemm)
        .def("get_num_configs", &torch_ext::Fp4GemmAllreduceRunner::getNumConfigs);
}
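For reference, a minimal sketch of exercising the binding registered above directly from Python, assuming the TensorRT-LLM extension library is already loaded and the listed ranks form the NVLS group; the rank/group values are illustrative, and actually invoking run_gemm requires every rank in the group to participate on NVLS-capable GPUs:

import torch

# Illustrative values: rank 0 of a two-GPU tensor-parallel group.
rank, group = 0, [0, 1]
runner = torch.classes.trtllm.Fp4GemmAllreduceRunner(torch.bfloat16, rank, group)

# Each index in [0, get_num_configs()) selects one supported CUTLASS launch
# config and can be passed as the configIdx argument of run_gemm.
print(runner.get_num_configs())

# out = runner.run_gemm(mat1, mat2, mat1_scale, mat2_scale, alpha, 0)
# where mat1/mat2 hold packed NVFP4 operands (K is inferred as mat1.size(1) * 2),
# mat1_scale/mat2_scale hold the block scales, and alpha is a float32 scalar tensor.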

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py

Lines changed: 96 additions & 0 deletions
@@ -1621,3 +1621,99 @@ def record_stream(tensor: torch.Tensor, stream_id: int) -> None:
     stream = get_stream(stream_id)
     assert stream is not None
     tensor.record_stream(stream)
+
+
+class Fp4GemmAllreduceRunner(TunableRunner):
+    runner_dict = dict()
+    tuning_config = TuningConfig(
+        dynamic_tensor_specs=(DynamicTensorSpec(
+            0, 0, get_last_power_of_2_num_tokens_buckets,
+            last_positive_power_of_2), ),
+        constraint_specs=(ConstraintSpec(2, 0, fp4_scale_infer_shape), ))
+
+    def __init__(
+        self,
+        output_dtype: torch.dtype,
+        tp_rank: int,
+        tp_group: List[int],
+    ):
+        self.output_dtype = output_dtype
+        self.tp_rank = tp_rank
+        self.tp_group_str = '-'.join(str(g) for g in tp_group)
+        instance_key = (output_dtype, self.tp_group_str)
+        if instance_key not in Fp4GemmAllreduceRunner.runner_dict:
+            Fp4GemmAllreduceRunner.runner_dict[
+                instance_key] = torch.classes.trtllm.Fp4GemmAllreduceRunner(
+                    output_dtype, tp_rank, tp_group)
+        self.fp4_gemm_all_reduce_runner = Fp4GemmAllreduceRunner.runner_dict[
+            instance_key]
+
+    def unique_id(self):
+        return (self.output_dtype, self.tp_group_str)
+
+    def get_valid_tactics(self, inputs: List[torch.Tensor],
+                          profile: OptimizationProfile, **kwargs) -> List[int]:
+        return list(range(self.fp4_gemm_all_reduce_runner.get_num_configs()))
+
+    def forward(
+        self,
+        inputs: List[torch.Tensor],
+        tactic: int = 0,
+    ) -> torch.Tensor:
+        mat1, mat2, mat1_scale, mat2_scale, global_scale = inputs
+        return self.fp4_gemm_all_reduce_runner.run_gemm(
+            mat1,
+            mat2,
+            mat1_scale,
+            mat2_scale,
+            global_scale,
+            tactic,
+        )
+
+
+@torch.library.custom_op("trtllm::nvfp4_gemm_allreduce", mutates_args=())
+def nvfp4_gemm_allreduce(
+    act_fp4: torch.Tensor,
+    weight: torch.Tensor,
+    act_sf: torch.Tensor,
+    weight_scale: torch.Tensor,
+    alpha: torch.Tensor,
+    output_dtype: torch.dtype,
+    tp_rank: int,
+    tp_group: List[int],
+) -> torch.Tensor:
+    AutoTuner.get()
+
+    # Use Cutlass runner with predefined configs
+    nvfp4_gemm_allreduce_runner = Fp4GemmAllreduceRunner(
+        output_dtype, tp_rank, tp_group)
+
+    # TODO: Enable auto-tuning
+    # runner_type = type(nvfp4_gemm_allreduce_runner).__name__
+    # _, best_tactic = tuner.choose_one(
+    #     f"trtllm::nvfp4_gemm_allreduce::{runner_type}",
+    #     [nvfp4_gemm_allreduce_runner],
+    #     nvfp4_gemm_allreduce_runner.tuning_config,
+    #     [act_fp4, weight, act_sf, weight_scale, alpha],
+    # )
+
+    best_tactic = -1
+
+    return nvfp4_gemm_allreduce_runner(
+        inputs=[act_fp4, weight, act_sf, weight_scale, alpha],
+        tactic=best_tactic)
+
+
+@nvfp4_gemm_allreduce.register_fake
+def _(
+    act_fp4: torch.Tensor,
+    weight: torch.Tensor,
+    act_sf: torch.Tensor,
+    weight_scale: torch.Tensor,
+    alpha: torch.Tensor,
+    output_dtype: torch.dtype,
+    tp_rank: int,
+    tp_group: List[int],
+) -> torch.Tensor:
+    return act_fp4.new_empty((act_fp4.size(0), weight.size(0)),
+                             dtype=output_dtype)
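Taken together, the op is meant to be called once per tensor-parallel rank in place of a separate NVFP4 GEMM followed by an all-reduce. A hedged sketch, assuming act_fp4/act_sf and weight/weight_scale were already produced by the existing NVFP4 quantization path, and with the mapping variable below purely illustrative for the caller's position in the TP group:

# Inside each TP rank; inputs are assumed to be NVFP4-packed already.
out = torch.ops.trtllm.nvfp4_gemm_allreduce(
    act_fp4, weight, act_sf, weight_scale, alpha,
    output_dtype=torch.bfloat16,
    tp_rank=mapping.tp_rank,    # illustrative: this rank's index in tp_group
    tp_group=mapping.tp_group,  # illustrative: e.g. [0, 1, 2, 3]
)
# out has shape (act_fp4.size(0), weight.size(0)) in output_dtype and already
# holds the sum across tp_group; no separate allreduce call is needed.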
