@@ -1901,6 +1901,7 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * 32, 1) combine(int4* co
         const int* rdma_channel_prefix_matrix,
         const int* rdma_rank_prefix_sum,
         const int* gbl_channel_prefix_matrix,
+        const int* num_recv_tokens_ptr,
         int num_tokens,
         int num_combined_tokens,
         int hidden,
@@ -2005,7 +2006,12 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * 32, 1) combine(int4* co
         if (lane_id < kNumRDMARanks) {
             int prefix_idx = (lane_id * NUM_MAX_NVL_PEERS + dst_nvl_rank) * num_channels + channel_id;
             token_start_idx = gbl_channel_prefix_matrix[prefix_idx];
-            token_end_idx = (prefix_idx == num_channels * num_ranks - 1) ? num_tokens : gbl_channel_prefix_matrix[prefix_idx + 1];
+            // The last `(rdma, nvl, channel)` slot has no `+1` neighbor, so its upper bound has to come
+            // from the real recv-token total. When `num_recv_tokens_ptr` is supplied (HT dispatch with
+            // `num_worst_tokens > 0` pads `x` to the worst-case size), read the device-side total to
+            // avoid sending into the padding region. Otherwise fall back to the input shape.
+            const int real_num_tokens = num_recv_tokens_ptr != nullptr ? __ldg(num_recv_tokens_ptr) : num_tokens;
+            token_end_idx = (prefix_idx == num_channels * num_ranks - 1) ? real_num_tokens : gbl_channel_prefix_matrix[prefix_idx + 1];
         }
         __syncwarp();
 
@@ -2513,6 +2519,7 @@ void combine(cudaDataType_t type,
              const int* rdma_channel_prefix_matrix,
              const int* rdma_rank_prefix_sum,
              const int* gbl_channel_prefix_matrix,
+             const int* num_recv_tokens_ptr,
              int num_tokens,
              int num_combined_tokens,
              int hidden,
@@ -2568,6 +2575,7 @@ void combine(cudaDataType_t type,
         rdma_channel_prefix_matrix, \
         rdma_rank_prefix_sum, \
         gbl_channel_prefix_matrix, \
+        num_recv_tokens_ptr, \
         num_tokens, \
         num_combined_tokens, \
         hidden, \
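The end-index selection in the kernel hunk above can be checked in isolation. Below is a minimal standalone C++ sketch of the same logic; `token_range`, the toy prefix matrix, and `real_num_tokens` are illustrative names rather than DeepEP code, and `real_num_tokens` stands in for the value the kernel reads via `__ldg(num_recv_tokens_ptr)` when the pointer is non-null.

// Minimal sketch (illustrative only, not DeepEP code) of the end-index selection
// performed on the flattened `gbl_channel_prefix_matrix`.
#include <cstdio>

// Returns the [start, end) token range for one flattened (rdma, nvl, channel) slot.
// `real_num_tokens` plays the role of __ldg(num_recv_tokens_ptr); callers without a
// device-side total would pass the tensor length instead.
static void token_range(const int* gbl_channel_prefix_matrix, int num_channels, int num_ranks,
                        int prefix_idx, int real_num_tokens, int* start, int* end) {
    *start = gbl_channel_prefix_matrix[prefix_idx];
    // The last slot has no `prefix_idx + 1` neighbor, so its upper bound must be the
    // real received-token count rather than the (possibly padded) tensor length.
    *end = (prefix_idx == num_channels * num_ranks - 1)
               ? real_num_tokens
               : gbl_channel_prefix_matrix[prefix_idx + 1];
}

int main() {
    // Hypothetical toy layout: 2 ranks x 2 channels, 10 real tokens, `x` padded to 16 rows.
    const int num_ranks = 2, num_channels = 2;
    const int prefix[4] = {0, 3, 5, 8};   // per-slot start offsets (already prefix-summed)
    const int real_num_tokens = 10;       // device-side total; the padded length is never used
    for (int idx = 0; idx < num_ranks * num_channels; ++idx) {
        int start, end;
        token_range(prefix, num_channels, num_ranks, idx, real_num_tokens, &start, &end);
        std::printf("slot %d: tokens [%d, %d)\n", idx, start, end);
    }
    return 0;
}

With the toy inputs the final slot resolves to [8, 10), so the trailing padded rows (indices 10 and up) are never read or sent, which is exactly the behavior the `real_num_tokens` fallback is meant to guarantee.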