Skip to content

Commit fb0abad

Browse files
committed
fix HT combine after removing busy-wait
1 parent 38efa01 commit fb0abad

5 files changed

Lines changed: 22 additions & 1 deletion

File tree

3rdparty/deep_ep/deep_ep.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1428,6 +1428,7 @@ Buffer::internode_combine(const Tensor& x,
14281428
const Tensor& gbl_channel_prefix_matrix,
14291429
Tensor& combined_rdma_head,
14301430
Tensor& combined_nvl_head,
1431+
const int* num_recv_tokens_ptr,
14311432
const Config& config)
14321433
{
14331434
const int num_channels = config.num_sms / 2;
@@ -1545,6 +1546,7 @@ Buffer::internode_combine(const Tensor& x,
15451546
rdma_channel_prefix_matrix.data<int>(),
15461547
rdma_rank_prefix_sum.data<int>(),
15471548
gbl_channel_prefix_matrix.data<int>(),
1549+
num_recv_tokens_ptr,
15481550
num_tokens,
15491551
num_combined_tokens,
15501552
hidden,

3rdparty/deep_ep/deep_ep.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,7 @@ class Buffer {
263263
const Tensor& gbl_channel_prefix_matrix,
264264
Tensor& combined_rdma_head,
265265
Tensor& combined_nvl_head,
266+
const int* num_recv_tokens_ptr,
266267
const Config& config);
267268

268269
Config get_dispatch_config();

3rdparty/deep_ep/kernels/api.cuh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ void combine(cudaDataType_t type,
278278
const int* rdma_channel_prefix_matrix,
279279
const int* rdma_rank_prefix_sum,
280280
const int* gbl_channel_prefix_matrix,
281+
const int* num_recv_tokens_ptr,
281282
int num_tokens,
282283
int num_combined_tokens,
283284
int hidden,

3rdparty/deep_ep/kernels/internode.cu

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1901,6 +1901,7 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * 32, 1) combine(int4* co
19011901
const int* rdma_channel_prefix_matrix,
19021902
const int* rdma_rank_prefix_sum,
19031903
const int* gbl_channel_prefix_matrix,
1904+
const int* num_recv_tokens_ptr,
19041905
int num_tokens,
19051906
int num_combined_tokens,
19061907
int hidden,
@@ -2005,7 +2006,12 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * 32, 1) combine(int4* co
20052006
if (lane_id < kNumRDMARanks) {
20062007
int prefix_idx = (lane_id * NUM_MAX_NVL_PEERS + dst_nvl_rank) * num_channels + channel_id;
20072008
token_start_idx = gbl_channel_prefix_matrix[prefix_idx];
2008-
token_end_idx = (prefix_idx == num_channels * num_ranks - 1) ? num_tokens : gbl_channel_prefix_matrix[prefix_idx + 1];
2009+
// The last `(rdma, nvl, channel)` slot has no `+1` neighbor, so its upper bound has to come
2010+
// from the real recv-token total. When `num_recv_tokens_ptr` is supplied (HT dispatch with
2011+
// `num_worst_tokens > 0` pads `x` to the worst-case size), read the device-side total to
2012+
// avoid sending into the padding region. Otherwise fall back to the input shape.
2013+
const int real_num_tokens = num_recv_tokens_ptr != nullptr ? __ldg(num_recv_tokens_ptr) : num_tokens;
2014+
token_end_idx = (prefix_idx == num_channels * num_ranks - 1) ? real_num_tokens : gbl_channel_prefix_matrix[prefix_idx + 1];
20092015
}
20102016
__syncwarp();
20112017

@@ -2513,6 +2519,7 @@ void combine(cudaDataType_t type,
25132519
const int* rdma_channel_prefix_matrix,
25142520
const int* rdma_rank_prefix_sum,
25152521
const int* gbl_channel_prefix_matrix,
2522+
const int* num_recv_tokens_ptr,
25162523
int num_tokens,
25172524
int num_combined_tokens,
25182525
int hidden,
@@ -2568,6 +2575,7 @@ void combine(cudaDataType_t type,
25682575
rdma_channel_prefix_matrix, \
25692576
rdma_rank_prefix_sum, \
25702577
gbl_channel_prefix_matrix, \
2578+
num_recv_tokens_ptr, \
25712579
num_tokens, \
25722580
num_combined_tokens, \
25732581
hidden, \

src/turbomind/comm/nccl/nccl_ep.cu

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,14 @@ void NcclCommImpl::Combine(const EpCombineInput& input, EpCombineOutput& output,
345345
auto combined_rdma_head = input.handle[8];
346346
auto combined_nvl_head = input.handle[9];
347347

348+
// Real recv-token total lives at the last slot of `recv_gbl_rank_prefix_sum`. The
349+
// internode combine kernel needs it to bound the very last (rdma, nvl, channel)
350+
// task range when HT dispatch was called with `num_worst_tokens > 0` (which pads
351+
// `input.x` past the real total).
352+
auto recv_gbl_rank_prefix_sum = input.handle[6];
353+
const int* num_recv_tokens_ptr =
354+
recv_gbl_rank_prefix_sum.data<int>() + recv_gbl_rank_prefix_sum.shape(0) - 1;
355+
348356
auto [combined_x, combined_topk_weights] = buffer_->internode_combine(input.x,
349357
std::nullopt,
350358
std::nullopt,
@@ -356,6 +364,7 @@ void NcclCommImpl::Combine(const EpCombineInput& input, EpCombineOutput& output,
356364
gbl_channel_prefix_matrix,
357365
combined_rdma_head,
358366
combined_nvl_head,
367+
num_recv_tokens_ptr,
359368
config);
360369
sync_check_cuda_error();
361370
output.out_x = combined_x;

0 commit comments

Comments (0)