alltoallOp.cpp
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/kernels/helixAllToAll.h"
#include "tensorrt_llm/runtime/torchUtils.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include "tensorrt_llm/thop/thUtils.h"
#include <vector>

TRTLLM_NAMESPACE_BEGIN

namespace torch_ext
{
#if ENABLE_MULTI_DEVICE
namespace
{
class AllToAllHelixOp
{
public:
AllToAllHelixOp(std::set<int> group)
: mGroup(std::move(group))
{
}
~AllToAllHelixOp() = default;
int initialize()
{
TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank());
mNcclComm = getComm(mGroup);
TLLM_CHECK_WITH_INFO(mGroup.size() > 0, "group size should be greater than 0");
TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank());
return 0;
}
std::vector<torch::Tensor> run(torch::TensorList input_list, torch::optional<int64_t> num_lists)
{
TLLM_CHECK_WITH_INFO(mNcclComm.get() != nullptr, "mNcclComm should be initialized before use");
auto num_lists_ = static_cast<int>(num_lists.value_or(1));
auto num_ranks = static_cast<int>(mGroup.size());
TLLM_CHECK_WITH_INFO(num_lists_ > 0, "num_lists should be greater than 0");
// note: together with the non-empty group, this ensures that input_list is not empty
TLLM_CHECK_WITH_INFO(static_cast<int>(input_list.size()) == num_ranks * num_lists_,
"input_list size should be equal to group size * num_lists");
for (auto const& input : input_list)
{
TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
}
std::vector<torch::Tensor> output_list(static_cast<size_t>(num_lists_));
auto stream = at::cuda::getCurrentCUDAStream(input_list[0].get_device());
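// Fused all-to-all: within a single NCCL group, send the tensor destined for each peer rank r
// and receive rank r's contribution into slice r of the stacked output, for every list.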
NCCLCHECK_THROW(ncclGroupStart());
for (int il = 0; il < num_lists_; ++il)
{
auto off = il * num_ranks;
auto output_shape = input_list[off].sizes().vec();
output_shape.insert(output_shape.begin(), num_ranks);
auto output = torch::empty(output_shape, input_list[off].options());
output_list[il] = output;
auto type = tensorrt_llm::runtime::TorchUtils::dataType(input_list[off].scalar_type());
auto nccl_type = (*getDtypeMap())[type];
for (int r = 0; r < num_ranks; ++r)
{
auto const& input = input_list[off + r];
ncclSend(input.data_ptr(), input.numel(), nccl_type, r, *mNcclComm, stream);
ncclRecv(output[r].mutable_data_ptr(), output[r].numel(), nccl_type, r, *mNcclComm, stream);
}
}
NCCLCHECK_THROW(ncclGroupEnd());
return output_list;
}
private:
std::set<int> mGroup;
std::shared_ptr<ncclComm_t> mNcclComm;
};
} // namespace
#endif // ENABLE_MULTI_DEVICE
std::vector<torch::Tensor> alltoall_helix(
torch::TensorList input_list, torch::List<int64_t> group_, torch::optional<int64_t> num_lists)
{
#if ENABLE_MULTI_DEVICE
std::set<int> group;
for (int64_t rank : group_)
{
group.insert(static_cast<int>(rank));
}
AllToAllHelixOp op(group);
op.initialize();
return op.run(input_list, num_lists);
#else
return {};
#endif // ENABLE_MULTI_DEVICE
}
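
// Illustrative usage of alltoall_helix (a minimal sketch; the 4-rank group, the [8, 128] half
// tensors and an already-initialized NCCL/MPI session are assumptions for illustration, not part
// of this file). Every rank passes num_lists * group_size contiguous tensors, ordered per list by
// destination rank, and receives num_lists tensors with a new leading group_size dimension, where
// slice r holds the data sent by rank r:
//
//   torch::List<int64_t> group({0, 1, 2, 3});
//   std::vector<torch::Tensor> inputs;
//   for (int r = 0; r < 4; ++r)
//   {
//       // tensor destined for peer rank r; all tensors in a list must share one shape
//       inputs.push_back(torch::randn({8, 128}, torch::dtype(torch::kHalf).device(torch::kCUDA)));
//   }
//   // outputs.size() == 1; outputs[0] has shape [4, 8, 128]
//   auto outputs = tensorrt_llm::torch_ext::alltoall_helix(inputs, group, 1);
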
/**
 * Helix All-to-All operation with two fields.
 *
 * Input tensors have shape [..., cp_size, kv_lora_rank] for partial_o and
 * [..., cp_size, 2] (or a larger multiple of 2) for softmax_stats. The operation
 * exchanges data along the cp_size dimension across all context-parallel ranks.
 *
 * @param partial_o Field 0 tensor (half or bfloat16, shape [..., cp_size, kv_lora_rank])
 * @param softmax_stats Field 1 tensor (float32, shape [..., cp_size, 2] or a multiple of 2
 *        in the last dimension)
 * @param workspace Workspace tensor (uint64, 2D, strided across ranks)
 * @param cp_rank Current context-parallel rank
 * @param cp_size Total number of context-parallel ranks
 * @return tuple of (partial_o_out, softmax_stats_out) with the same shapes as the inputs
 */
std::tuple<torch::Tensor, torch::Tensor> alltoall_helix_native(
torch::Tensor partial_o, torch::Tensor softmax_stats, torch::Tensor workspace, int64_t cp_rank, int64_t cp_size)
{
// Input validation
CHECK_TH_CUDA(partial_o);
CHECK_TH_CUDA(softmax_stats);
CHECK_TH_CUDA(workspace);
CHECK_CONTIGUOUS(partial_o);
CHECK_CONTIGUOUS(softmax_stats);
// Type checks
TORCH_CHECK(partial_o.scalar_type() == at::ScalarType::Half || partial_o.scalar_type() == at::ScalarType::BFloat16,
"partial_o must be half or bfloat16");
CHECK_TYPE(softmax_stats, at::ScalarType::Float);
CHECK_TYPE(workspace, at::ScalarType::UInt64);
// Shape validation
TORCH_CHECK(partial_o.dim() >= 2, "partial_o must have at least 2 dimensions");
TORCH_CHECK(softmax_stats.dim() >= 2, "softmax_stats must have at least 2 dimensions");
TORCH_CHECK(
partial_o.dim() == softmax_stats.dim(), "partial_o and softmax_stats must have same number of dimensions");
// Get dimensions
int kv_lora_rank = partial_o.size(-1);
TORCH_CHECK(partial_o.size(-2) == cp_size && softmax_stats.size(-2) == cp_size,
"partial_o/softmax_stats second-to-last dimension must equal cp_size");
TORCH_CHECK(softmax_stats.size(-1) % 2 == 0 && softmax_stats.size(-1) >= 2,
"softmax_stats last dimension must be a positive multiple of 2 (float2)");
bool allowVariableField1 = softmax_stats.size(-1) > 2;
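// When softmax_stats carries more than one float2 per entry, the kernel launch below enables
// variable field-1 support.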
// Check that leading dimensions match
for (int i = 0; i < partial_o.dim() - 2; i++)
{
TORCH_CHECK(partial_o.size(i) == softmax_stats.size(i),
"partial_o and softmax_stats must have matching dimensions except last two");
}
TORCH_CHECK(partial_o.size(-1) * partial_o.element_size() % 16 == 0, "partial_o must be aligned to 16 bytes");
TORCH_CHECK(workspace.dim() == 2, "workspace must be 2D (strided across ranks)");
TORCH_CHECK(workspace.size(0) == cp_size, "workspace must have cp_size rows");
// Calculate entry count (product of all dimensions before cp_size)
// This is the number of entries to process per peer rank
int entry_count = 1;
for (int i = 0; i < partial_o.dim() - 2; i++)
{
entry_count *= partial_o.size(i);
}
// Reshape to 3D: [entry_count, cp_size, feature_dim]
torch::Tensor partial_o_3d = partial_o.reshape({entry_count, cp_size, kv_lora_rank});
torch::Tensor softmax_stats_3d = softmax_stats.reshape({entry_count, cp_size, softmax_stats.size(-1)});
// Allocate output tensors (same shape as input)
torch::Tensor partial_o_out = torch::empty_like(partial_o);
torch::Tensor softmax_stats_out = torch::empty_like(softmax_stats);
torch::Tensor partial_o_out_3d = partial_o_out.reshape({entry_count, cp_size, kv_lora_rank});
torch::Tensor softmax_stats_out_3d = softmax_stats_out.reshape({entry_count, cp_size, softmax_stats.size(-1)});
// Setup parameters
tensorrt_llm::kernels::HelixAllToAllParams params;
// Field 0 (variable size half)
params.sendFields[0].dataPtr = reinterpret_cast<uint8_t*>(partial_o_3d.data_ptr());
params.sendFields[0].elementCount = kv_lora_rank;
params.sendFields[0].elementSize = partial_o.element_size();
params.sendFields[0].stride = partial_o_3d.stride(1) * partial_o.element_size();
params.recvFields[0].dataPtr = reinterpret_cast<uint8_t*>(partial_o_out_3d.data_ptr());
params.recvFields[0].elementCount = kv_lora_rank;
params.recvFields[0].elementSize = partial_o.element_size();
params.recvFields[0].stride = partial_o_out_3d.stride(1) * partial_o.element_size();
// Field 1 (single float2)
params.sendFields[1].dataPtr = reinterpret_cast<uint8_t*>(softmax_stats_3d.data_ptr<float>());
params.sendFields[1].elementCount = softmax_stats.size(-1);
params.sendFields[1].elementSize = softmax_stats.element_size();
params.sendFields[1].stride = softmax_stats_3d.stride(1) * softmax_stats.element_size();
params.recvFields[1].dataPtr = reinterpret_cast<uint8_t*>(softmax_stats_out_3d.data_ptr<float>());
params.recvFields[1].elementCount = softmax_stats.size(-1);
params.recvFields[1].elementSize = softmax_stats.element_size();
params.recvFields[1].stride = softmax_stats_out_3d.stride(1) * softmax_stats.element_size();
// Entry count and workspace
params.entryCount = entry_count;
params.workspace = workspace.data_ptr<uint64_t>();
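// stride(0) is the distance, in uint64 elements, between consecutive ranks' workspace rows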
params.workspaceStrideInU64 = workspace.stride(0);
// CP info
params.cpRank = cp_rank;
params.cpSize = cp_size;
params.channelCount = 0; // auto-compute
params.maxChannelCount = tensorrt_llm::kernels::computeHelixMaxChannelCount(cp_size);
// Launch kernel
auto stream = at::cuda::getCurrentCUDAStream();
tensorrt_llm::kernels::launchHelixAllToAll(params, allowVariableField1, stream);
return std::make_tuple(partial_o_out, softmax_stats_out);
}
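
// Illustrative call sequence for the native path (a minimal sketch under assumptions: how the
// cross-rank-visible uint64 workspace of shape [cp_size, slots] is allocated and exchanged is
// outside this file, and the tensor dimensions below are made up for illustration):
//
//   int64_t tokens = 4, heads = 8, kv_lora_rank = 512; // 512 * 2 bytes keeps the 16-byte alignment
//
//   // once, after the workspace buffer has been set up on this rank
//   tensorrt_llm::torch_ext::initialize_helix_workspace(workspace, cp_rank, cp_size);
//
//   // per call: exchange both fields along the cp_size dimension
//   auto partial_o = torch::empty({tokens, heads, cp_size, kv_lora_rank},
//       torch::dtype(torch::kBFloat16).device(torch::kCUDA));
//   auto softmax_stats = torch::empty({tokens, heads, cp_size, 2},
//       torch::dtype(torch::kFloat).device(torch::kCUDA));
//   auto [partial_o_out, softmax_stats_out] = tensorrt_llm::torch_ext::alltoall_helix_native(
//       partial_o, softmax_stats, workspace, cp_rank, cp_size);
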
/**
* Initialize workspace for helix all-to-all
*/
void initialize_helix_workspace(torch::Tensor workspace, int64_t cp_rank, int64_t cp_size)
{
CHECK_TH_CUDA(workspace);
CHECK_TYPE(workspace, at::ScalarType::UInt64);
TORCH_CHECK(workspace.dim() == 2, "workspace must be 2D");
TORCH_CHECK(workspace.size(0) == cp_size, "workspace must have cp_size rows");
TORCH_CHECK(cp_rank >= 0 && cp_rank < cp_size, "cp_rank must be in [0, cp_size)");
auto stream = at::cuda::getCurrentCUDAStream();
uint64_t* global_workspace_ptr = workspace.data_ptr<uint64_t>();
uint64_t* local_workspace_ptr = workspace[cp_rank].data_ptr<uint64_t>();
TORCH_CHECK(local_workspace_ptr == global_workspace_ptr + cp_rank * workspace.stride(0),
"local_workspace_ptr must be at the correct offset in the global "
"workspace");
tensorrt_llm::kernels::initializeHelixWorkspace(local_workspace_ptr, cp_size, stream);
}
} // namespace torch_ext
TRTLLM_NAMESPACE_END
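
// Registering the schemas below with the "trtllm" library makes these operators callable from
// Python as torch.ops.trtllm.alltoall_helix, torch.ops.trtllm.alltoall_helix_native and
// torch.ops.trtllm.initialize_helix_workspace.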
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def("alltoall_helix(Tensor[] input_list, int[] group, int? num_lists) -> Tensor[]");
m.def(
"alltoall_helix_native(Tensor partial_o, Tensor softmax_stats, Tensor(a!) workspace, int "
"cp_rank, int cp_size) -> (Tensor, Tensor)");
m.def(
"initialize_helix_workspace(Tensor(a!) workspace, int cp_rank, int cp_size) "
"-> ()");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("alltoall_helix", &tensorrt_llm::torch_ext::alltoall_helix);
m.impl("alltoall_helix_native", &tensorrt_llm::torch_ext::alltoall_helix_native);
m.impl("initialize_helix_workspace", &tensorrt_llm::torch_ext::initialize_helix_workspace);
}