
Commit d813fba

add fp4 gemm + allreduce
Signed-off-by: benzh <[email protected]>
1 parent 137713a commit d813fba

File tree: 5 files changed, +490 -10 lines changed

cpp/tensorrt_llm/thop/CMakeLists.txt

Lines changed: 2 additions & 1 deletion

@@ -104,7 +104,8 @@ add_library(
   loraOp.cpp
   finegrained_mixed_dtype_gemm_thop.cpp
   tinygemm2.cpp
-  dsv3RopeOp.cpp)
+  dsv3RopeOp.cpp
+  fusedGemmAllreduceOp.cpp)
 set_property(TARGET th_common PROPERTY POSITION_INDEPENDENT_CODE ON)
 target_link_libraries(
   th_common PRIVATE ${TORCH_LIBRARIES} th_utils ${Python3_LIBRARIES}
cpp/tensorrt_llm/thop/fusedGemmAllreduceOp.cpp (new file)

Lines changed: 150 additions & 0 deletions

@@ -0,0 +1,150 @@
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "cutlass_extensions/gemm_configs.h"

#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/cutlass_kernels/include/allreduce_gemm_runner.h"
#include "tensorrt_llm/runtime/ipcNvlsMemory.h"
#include "tensorrt_llm/thop/thUtils.h"

#include <cstddef>
#include <cuda_fp16.h>

#include <cstdint>
#include <functional>
#include <type_traits>
#include <vector>

using tensorrt_llm::kernels::opened_cutlass_kernels::GemmAllReduceImplRunner;
using tensorrt_llm::kernels::opened_cutlass_kernels::GemmAllReduceImplInterface;
using tensorrt_llm::kernels::opened_cutlass_kernels::GemmTypes;
using tensorrt_llm::kernels::opened_cutlass_kernels::PersistentWorkspaceInterface;

namespace torch_ext
{
// Cache the persistent workspace per thread and grow it only when a larger problem needs more space.
PersistentWorkspaceInterface* getWorkspace(
    GemmAllReduceImplInterface* runner, GemmAllReduceImplInterface::ProblemArgs const& problem)
{
    thread_local std::shared_ptr<PersistentWorkspaceInterface> curWorkspace;
    thread_local size_t curWorkspaceSize = 0;
    auto newWorkspace = runner->getPersistentWorkspace(problem);
    if (newWorkspace->size() > curWorkspaceSize)
    {
        TLLM_LOG_WARNING(
            "Fp4GemmAllreduceRunner workspace is not large enough; allocating a new workspace of %zu bytes",
            newWorkspace->size());
        newWorkspace->allocate();
        curWorkspaceSize = newWorkspace->size();
        curWorkspace = newWorkspace;
    }
    return curWorkspace.get();
}

class Fp4GemmAllreduceRunner : public torch::CustomClassHolder
{
public:
    explicit Fp4GemmAllreduceRunner(at::ScalarType outputDtype, int64_t rank, torch::List<int64_t> group)
        : mOutputDtype(outputDtype)
        , mRank(rank)
    {
        for (int64_t rank : group)
        {
            mGroup.insert(static_cast<int>(rank));
        }

        if (outputDtype == at::ScalarType::Half)
        {
            using Traits = GemmTypes<cutlass::float_e2m1_t, cutlass::float_e2m1_t, cutlass::half_t, cutlass::half_t,
                cutlass::float_ue4m3_t, cutlass::float_ue4m3_t, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor,
                cutlass::layout::RowMajor, cutlass::layout::RowMajor>;
            mRunner = std::make_shared<GemmAllReduceImplRunner<Traits>>();
        }
        else if (outputDtype == at::ScalarType::BFloat16)
        {
            using Traits = GemmTypes<cutlass::float_e2m1_t, cutlass::float_e2m1_t, cutlass::bfloat16_t,
                cutlass::bfloat16_t, cutlass::float_ue4m3_t, cutlass::float_ue4m3_t, cutlass::layout::RowMajor,
                cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, cutlass::layout::RowMajor>;
            mRunner = std::make_shared<GemmAllReduceImplRunner<Traits>>();
        }
        else
        {
            C10_THROW_ERROR(NotImplementedError, "Unsupported input or output dtype");
        }

        mConfigs = mRunner->getSupportedLaunchConfigs();
    }

    at::Tensor runGemm(at::Tensor const& mat1, at::Tensor const& mat2, at::Tensor const& mat1Scale,
        at::Tensor const& mat2Scale, at::Tensor const& alpha, int64_t configIdx) const
    {
        if (configIdx < 0)
            configIdx = 0;

        TORCH_CHECK(configIdx < int64_t(mConfigs.size()), "configIdx out of bounds");
        const int64_t M = mat1.size(0);
        const int64_t N = mat2.size(0);
        // FP4 operands pack two e2m1 values per byte, so the logical K is twice the stored width.
        const int64_t K = mat1.size(1) * 2;

        GemmAllReduceImplInterface::ProblemArgs problemArgs;
        problemArgs.argProblemShape(M, N, K, 1);
        problemArgs.argA(mat1.data_ptr());
        problemArgs.argB(mat2.data_ptr());
        problemArgs.argAScale(mat1Scale.data_ptr());
        problemArgs.argBScale(mat2Scale.data_ptr());
        problemArgs.argC(nullptr);
        problemArgs.argAlphaPtr(reinterpret_cast<float const*>(alpha.const_data_ptr()));
        problemArgs.argBeta(0.f);
        problemArgs.argRanks(mRank, mGroup);
        problemArgs.argLaunchConfig(mConfigs[configIdx]);

        // The output D lives in NVLS (multicast) memory so the fused all-reduce can write across ranks.
        size_t dSize = M * N * c10::elementSize(mOutputDtype);
        auto handle = tensorrt_llm::runtime::ipcNvlsAllocate(dSize, mGroup);
        problemArgs.argD((void*) handle->uc_ptr, (void*) handle->mc_ptr, (void**) handle->ipc_uc_ptrs.data());

        auto workspace = getWorkspace(mRunner.get(), problemArgs);
        problemArgs.argWorkspace(workspace);

        auto stream = at::cuda::getCurrentCUDAStream(mat1.get_device());
        mRunner->run(problemArgs, stream);

        auto options = mat1.options().dtype(mOutputDtype);
        // Release the NVLS allocation when the returned tensor is freed.
        auto deleter = [=](void* unused) { ipcNvlsFree(handle); };
        auto D = at::from_blob((void*) handle->uc_ptr, {M, N}, {N, 1}, deleter, options);
        return D;
    }

    int64_t getNumConfigs() const
    {
        return static_cast<int64_t>(mConfigs.size());
    }

private:
    at::ScalarType mOutputDtype;
    int mRank;
    std::set<int> mGroup;
    std::shared_ptr<GemmAllReduceImplInterface> mRunner{nullptr};
    std::vector<GemmAllReduceImplInterface::LaunchConfig> mConfigs;
};

} // namespace torch_ext

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.class_<torch_ext::Fp4GemmAllreduceRunner>("Fp4GemmAllreduceRunner")
        .def(torch::init<at::ScalarType, int64_t, torch::List<int64_t>>())
        .def("run_gemm", &torch_ext::Fp4GemmAllreduceRunner::runGemm)
        .def("get_num_configs", &torch_ext::Fp4GemmAllreduceRunner::getNumConfigs);
}
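
A note on the shape conventions implied by runGemm above: the FP4 operands are stored packed, two e2m1 values per byte along K, so the activation is read as [M, K/2], the weight as [N, K/2], and the all-reduced output D is [M, N] in the requested output dtype. A minimal Python sketch of that arithmetic (the helper name is illustrative, not part of this commit):

def nvfp4_gemm_allreduce_shapes(mat1_shape, mat2_shape):
    """Mirror the shape logic in Fp4GemmAllreduceRunner::runGemm (illustrative helper)."""
    m, k_packed = mat1_shape      # activation: [M, K/2], two e2m1 values per byte
    n, k_packed_b = mat2_shape    # weight:     [N, K/2], same packing
    assert k_packed == k_packed_b, "A and B must agree on the packed K dimension"
    k = k_packed * 2              # logical K, as computed by `mat1.size(1) * 2`
    return (m, n, k), (m, n)      # GEMM problem shape and output shape

problem, out_shape = nvfp4_gemm_allreduce_shapes((128, 256), (4096, 256))
print(problem)    # (128, 4096, 512)
print(out_shape)  # (128, 4096)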

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py

Lines changed: 93 additions & 0 deletions
@@ -1621,3 +1621,96 @@ def record_stream(tensor: torch.Tensor, stream_id: int) -> None:
     stream = get_stream(stream_id)
     assert stream is not None
     tensor.record_stream(stream)


class Fp4GemmAllreduceRunner(TunableRunner):
    runner_dict = dict()
    tuning_config = TuningConfig(
        dynamic_tensor_specs=(DynamicTensorSpec(
            0, 0, get_last_power_of_2_num_tokens_buckets,
            last_positive_power_of_2), ),
        constraint_specs=(ConstraintSpec(2, 0, fp4_scale_infer_shape), ))

    def __init__(
        self,
        output_dtype: torch.dtype,
        tp_rank: int,
        tp_group: List[int],
    ):
        self.output_dtype = output_dtype
        self.tp_rank = tp_rank
        self.tp_group_str = '-'.join(str(g) for g in tp_group)
        instance_key = (output_dtype, self.tp_group_str)
        if instance_key not in Fp4GemmAllreduceRunner.runner_dict:
            Fp4GemmAllreduceRunner.runner_dict[
                instance_key] = torch.classes.trtllm.Fp4GemmAllreduceRunner(
                    output_dtype, tp_rank, tp_group)
        self.fp4_gemm_all_reduce_runner = Fp4GemmAllreduceRunner.runner_dict[
            instance_key]

    def unique_id(self):
        return (self.output_dtype, self.tp_group_str)

    def get_valid_tactics(self, inputs: List[torch.Tensor],
                          profile: OptimizationProfile, **kwargs) -> List[int]:
        return list(range(self.fp4_gemm_all_reduce_runner.get_num_configs()))

    def forward(
        self,
        inputs: List[torch.Tensor],
        tactic: int = 0,
    ) -> torch.Tensor:
        mat1, mat2, mat1_scale, mat2_scale, global_scale = inputs
        return self.fp4_gemm_all_reduce_runner.run_gemm(
            mat1,
            mat2,
            mat1_scale,
            mat2_scale,
            global_scale,
            tactic,
        )


@torch.library.custom_op("trtllm::nvfp4_gemm_allreduce", mutates_args=())
def nvfp4_gemm_allreduce(
    act_fp4: torch.Tensor,
    weight: torch.Tensor,
    act_sf: torch.Tensor,
    weight_scale: torch.Tensor,
    alpha: torch.Tensor,
    output_dtype: torch.dtype,
    tp_rank: int,
    tp_group: List[int],
) -> torch.Tensor:
    tuner = AutoTuner.get()

    # Use Cutlass runner with predefined configs
    nvfp4_gemm_allreduce_runner = Fp4GemmAllreduceRunner(
        output_dtype, tp_rank, tp_group)

    runner_type = type(nvfp4_gemm_allreduce_runner).__name__
    _, best_tactic = tuner.choose_one(
        f"trtllm::nvfp4_gemm_allreduce::{runner_type}",
        [nvfp4_gemm_allreduce_runner],
        nvfp4_gemm_allreduce_runner.tuning_config,
        [act_fp4, weight, act_sf, weight_scale, alpha],
    )

    return nvfp4_gemm_allreduce_runner(
        inputs=[act_fp4, weight, act_sf, weight_scale, alpha],
        tactic=best_tactic)


@nvfp4_gemm_allreduce.register_fake
def _(
    act_fp4: torch.Tensor,
    weight: torch.Tensor,
    act_sf: torch.Tensor,
    weight_scale: torch.Tensor,
    alpha: torch.Tensor,
    output_dtype: torch.dtype,
    tp_rank: int,
    tp_group: List[int],
) -> torch.Tensor:
    return act_fp4.new_empty((act_fp4.size(0), weight.size(0)),
                             dtype=output_dtype)
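
For completeness, a minimal call sketch of the new op from a tensor-parallel worker. The wrapper name, the bfloat16 output choice, and the assumption that the packed NVFP4 operands and their scale factors already exist are illustrative and not part of this diff; every rank in tp_group must invoke the op collectively, and the returned tensor already holds the all-reduced result.

import torch
from typing import List

def fused_linear_allreduce(act_fp4: torch.Tensor, weight: torch.Tensor,
                           act_sf: torch.Tensor, weight_scale: torch.Tensor,
                           alpha: torch.Tensor, tp_rank: int,
                           tp_group: List[int]) -> torch.Tensor:
    # Single fused GEMM + all-reduce call; no separate allreduce is needed afterwards.
    return torch.ops.trtllm.nvfp4_gemm_allreduce(act_fp4, weight, act_sf,
                                                 weight_scale, alpha,
                                                 torch.bfloat16, tp_rank,
                                                 tp_group)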
