
Commit 74d7798

add fp4 gemm + allreduce

Signed-off-by: benzh <[email protected]>

1 parent e06c582 commit 74d7798

File tree

5 files changed: +497 additions, −10 deletions

cpp/tensorrt_llm/thop/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
@@ -104,7 +104,8 @@ add_library(
   loraOp.cpp
   finegrained_mixed_dtype_gemm_thop.cpp
   tinygemm2.cpp
-  dsv3RopeOp.cpp)
+  dsv3RopeOp.cpp
+  fusedGemmAllreduceOp.cpp)
 set_property(TARGET th_common PROPERTY POSITION_INDEPENDENT_CODE ON)
 target_link_libraries(
   th_common PRIVATE ${TORCH_LIBRARIES} th_utils ${Python3_LIBRARIES}
cpp/tensorrt_llm/thop/fusedGemmAllreduceOp.cpp (new file)

Lines changed: 152 additions & 0 deletions
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "allreduce_gemm_runner.h"
#include "cutlass_extensions/gemm_configs.h"

#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/cutlass_kernels/include/allreduce_gemm_runner.h"
#include "tensorrt_llm/runtime/ipcNvlsMemory.h"
#include "tensorrt_llm/thop/thUtils.h"

#include <ATen/cuda/EmptyTensor.h>
#include <ATen/native/cuda/Resize.h>

#include <cstddef>
#include <cuda_fp16.h>

#include <cassert>
#include <cstdint>
#include <functional>
#include <memory>
#include <set>
#include <type_traits>
#include <vector>

using tensorrt_llm::kernels::opened_cutlass_kernels::GemmAllReduceImplRunner;
using tensorrt_llm::kernels::opened_cutlass_kernels::GemmAllReduceImplInterface;
using tensorrt_llm::kernels::opened_cutlass_kernels::GemmTypes;
using tensorrt_llm::kernels::opened_cutlass_kernels::PersistentWorkspaceInterface;

namespace torch_ext
{
// Returns a persistent workspace for the given problem, growing the
// thread-local cached allocation only when a larger size is required.
PersistentWorkspaceInterface* getWorkspace(
    GemmAllReduceImplInterface* runner, GemmAllReduceImplInterface::ProblemArgs const& problem)
{
    thread_local std::shared_ptr<PersistentWorkspaceInterface> curWorkspace;
    thread_local size_t curWorkspaceSize = 0;
    auto newWorkspace = runner->getPersistentWorkspace(problem);
    if (newWorkspace->size() > curWorkspaceSize)
    {
        newWorkspace->allocate();
        curWorkspaceSize = newWorkspace->size();
        curWorkspace = newWorkspace;
    }
    return curWorkspace.get();
}

class Fp4GemmAllreduceRunner : public torch::CustomClassHolder
{
public:
    explicit Fp4GemmAllreduceRunner(at::ScalarType outputDtype, int64_t rank, torch::List<int64_t> group)
        : mOutputDtype(outputDtype)
        , mRank(rank)
    {
        for (int64_t rank : group)
        {
            mGroup.insert(static_cast<int>(rank));
        }

        if (outputDtype == at::ScalarType::Half)
        {
            using Traits = GemmTypes<cutlass::float_e2m1_t, cutlass::float_e2m1_t, cutlass::half_t, cutlass::half_t,
                cutlass::float_ue4m3_t, cutlass::float_ue4m3_t, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor,
                cutlass::layout::RowMajor, cutlass::layout::RowMajor>;
            mRunner = std::make_shared<GemmAllReduceImplRunner<Traits>>();
        }
        else if (outputDtype == at::ScalarType::BFloat16)
        {
            using Traits = GemmTypes<cutlass::float_e2m1_t, cutlass::float_e2m1_t, cutlass::bfloat16_t,
                cutlass::bfloat16_t, cutlass::float_ue4m3_t, cutlass::float_ue4m3_t, cutlass::layout::RowMajor,
                cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, cutlass::layout::RowMajor>;
            mRunner = std::make_shared<GemmAllReduceImplRunner<Traits>>();
        }
        else
        {
            C10_THROW_ERROR(NotImplementedError, "Unsupported input or output dtype");
        }

        mConfigs = mRunner->getSupportedLaunchConfigs();
    }

    at::Tensor runGemm(at::Tensor const& mat1, at::Tensor const& mat2, at::Tensor const& mat1Scale,
        at::Tensor const& mat2Scale, at::Tensor const& alpha, int64_t configIdx) const
    {
        if (configIdx < 0)
            configIdx = 0;

        assert(configIdx < int64_t(mConfigs.size()));
        const int64_t M = mat1.size(0);
        const int64_t N = mat2.size(0);
        // FP4 elements are packed two per byte, so the logical K dim is twice the stored one.
        const int64_t K = mat1.size(1) * 2;

        GemmAllReduceImplInterface::ProblemArgs problemArgs;
        problemArgs.argProblemShape(M, N, K, 1);
        problemArgs.argA(mat1.data_ptr());
        problemArgs.argB(mat2.data_ptr());
        problemArgs.argAScale(mat1Scale.data_ptr());
        problemArgs.argBScale(mat2Scale.data_ptr());
        problemArgs.argC(nullptr);
        problemArgs.argAlphaPtr(alpha.const_data_ptr<float>());
        problemArgs.argBeta(0.f);
        problemArgs.argRanks(mRank, mGroup);
        problemArgs.argLaunchConfig(mConfigs[configIdx]);

        // Allocate the output in NVLS multicast memory so the fused
        // all-reduce can write results visible to every rank in the group.
        size_t dSize = M * N * c10::elementSize(mOutputDtype);
        auto handle = tensorrt_llm::runtime::ipcNvlsAllocate(dSize, mGroup);
        problemArgs.argD((void*) handle->uc_ptr, (void*) handle->mc_ptr, (void**) handle->ipc_uc_ptrs.data());

        auto workspace = getWorkspace(mRunner.get(), problemArgs);
        problemArgs.argWorkspace(workspace);

        auto stream = at::cuda::getCurrentCUDAStream(mat1.get_device());
        mRunner->run(problemArgs, stream);

        // Wrap the unicast view of the NVLS buffer as a tensor; the deleter
        // frees the multicast allocation when the tensor is released.
        auto options = mat1.options().dtype(mOutputDtype);
        auto deleter = [=](void* unused) { ipcNvlsFree(handle); };
        auto D = at::from_blob((void*) handle->uc_ptr, {M, N}, {N, 1}, deleter, options);
        return D;
    }

    int64_t getNumConfigs() const
    {
        return static_cast<int64_t>(mConfigs.size());
    }

private:
    at::ScalarType mOutputDtype;
    int mRank;
    std::set<int> mGroup;
    std::shared_ptr<GemmAllReduceImplInterface> mRunner{nullptr};
    std::vector<GemmAllReduceImplInterface::LaunchConfig> mConfigs;
};

} // namespace torch_ext

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.class_<torch_ext::Fp4GemmAllreduceRunner>("Fp4GemmAllreduceRunner")
        .def(torch::init<at::ScalarType, int64_t, torch::List<int64_t>>())
        .def("run_gemm", &torch_ext::Fp4GemmAllreduceRunner::runGemm)
        .def("get_num_configs", &torch_ext::Fp4GemmAllreduceRunner::getNumConfigs);
}
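
For orientation, a minimal sketch (not part of this diff) of how the class registered by the TORCH_LIBRARY_FRAGMENT above becomes reachable from Python once the TensorRT-LLM extension library is loaded; the dtype, rank, and group values here are illustrative:

import torch

# Assumes the shared library registering the 'trtllm' fragment is already loaded.
runner = torch.classes.trtllm.Fp4GemmAllreduceRunner(
    torch.float16,  # output dtype; Half and BFloat16 are the supported cases
    0,              # this process's rank
    [0, 1],         # ranks in the tensor-parallel group
)
print(runner.get_num_configs())  # number of tunable launch configurations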

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py

Lines changed: 93 additions & 0 deletions
@@ -1614,3 +1614,96 @@ def record_stream(tensor: torch.Tensor, stream_id: int) -> None:
    stream = get_stream(stream_id)
    assert stream is not None
    tensor.record_stream(stream)


class Fp4GemmAllreduceRunner(TunableRunner):
    runner_dict = dict()
    tuning_config = TuningConfig(dynamic_tensor_specs=(DynamicTensorSpec(
        0, 0, get_last_power_of_2_num_tokens_buckets,
        last_positive_power_of_2), ),
                                 constraint_specs=(ConstraintSpec(
                                     2, 0, fp4_scale_infer_shape), ))

    def __init__(
        self,
        output_dtype: torch.dtype,
        tp_rank: int,
        tp_group: List[int],
    ):
        self.output_dtype = output_dtype
        self.tp_rank = tp_rank
        self.tp_group_str = '-'.join(str(g) for g in tp_group)
        # Cache one C++ runner per (dtype, group) so repeated calls reuse it.
        instance_key = (output_dtype, self.tp_group_str)
        if instance_key not in Fp4GemmAllreduceRunner.runner_dict:
            Fp4GemmAllreduceRunner.runner_dict[
                instance_key] = torch.classes.trtllm.Fp4GemmAllreduceRunner(
                    output_dtype, tp_rank, tp_group)
        self.fp4_gemm_all_reduce_runner = Fp4GemmAllreduceRunner.runner_dict[
            instance_key]

    def unique_id(self):
        return (self.output_dtype, self.tp_group_str)

    def get_valid_tactics(self, inputs: List[torch.Tensor],
                          profile: OptimizationProfile, **kwargs) -> List[int]:
        return list(range(self.fp4_gemm_all_reduce_runner.get_num_configs()))

    def forward(
        self,
        inputs: List[torch.Tensor],
        tactic: int = 0,
    ) -> torch.Tensor:
        mat1, mat2, mat1_scale, mat2_scale, global_scale = inputs
        return self.fp4_gemm_all_reduce_runner.run_gemm(
            mat1,
            mat2,
            mat1_scale,
            mat2_scale,
            global_scale,
            tactic,
        )


@torch.library.custom_op("trtllm::nvfp4_gemm_allreduce", mutates_args=())
def nvfp4_gemm_allreduce(
    act_fp4: torch.Tensor,
    weight: torch.Tensor,
    act_sf: torch.Tensor,
    weight_scale: torch.Tensor,
    alpha: torch.Tensor,
    output_dtype: torch.dtype,
    tp_rank: int,
    tp_group: List[int],
) -> torch.Tensor:
    tuner = AutoTuner.get()

    # Use Cutlass runner with predefined configs.
    nvfp4_gemm_allreduce_runner = Fp4GemmAllreduceRunner(
        output_dtype, tp_rank, tp_group)

    runner_type = type(nvfp4_gemm_allreduce_runner).__name__
    _, best_tactic = tuner.choose_one(
        f"trtllm::nvfp4_gemm_allreduce::{runner_type}",
        [nvfp4_gemm_allreduce_runner],
        nvfp4_gemm_allreduce_runner.tuning_config,
        [act_fp4, weight, act_sf, weight_scale, alpha],
    )

    return nvfp4_gemm_allreduce_runner(
        inputs=[act_fp4, weight, act_sf, weight_scale, alpha],
        tactic=best_tactic)


@nvfp4_gemm_allreduce.register_fake
def _(
    act_fp4: torch.Tensor,
    weight: torch.Tensor,
    act_sf: torch.Tensor,
    weight_scale: torch.Tensor,
    alpha: torch.Tensor,
    output_dtype: torch.dtype,
    tp_rank: int,
    tp_group: List[int],
) -> torch.Tensor:
    return act_fp4.new_empty((act_fp4.size(0), weight.size(0)),
                             dtype=output_dtype)
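
For illustration, a minimal usage sketch (not part of this diff) of the registered op from one rank of a hypothetical two-GPU tensor-parallel group. The shapes, the randomly filled packed operands, and the scale-factor layout are placeholders; real callers produce these tensors through TRT-LLM's NVFP4 quantization path, whose swizzled scale layout is elided here:

import torch

tp_rank, tp_group = 0, [0, 1]   # hypothetical two-rank TP group
M, N, K = 128, 256, 512

# NVFP4 packs two e2m1 values per byte, so each row stores K/2 bytes.
act_fp4 = torch.randint(0, 255, (M, K // 2), dtype=torch.uint8, device='cuda')
weight = torch.randint(0, 255, (N, K // 2), dtype=torch.uint8, device='cuda')

# Block scale factors (one e4m3 scale per 16-element block) and a global alpha.
act_sf = torch.empty(M, K // 16, dtype=torch.uint8, device='cuda')
weight_scale = torch.empty(N, K // 16, dtype=torch.uint8, device='cuda')
alpha = torch.tensor([1.0], dtype=torch.float32, device='cuda')

out = torch.ops.trtllm.nvfp4_gemm_allreduce(
    act_fp4, weight, act_sf, weight_scale, alpha,
    torch.float16, tp_rank, tp_group)
# out: (M, N) half tensor holding the GEMM result all-reduced across tp_group.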
