@@ -143,7 +143,7 @@ Buffer::get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts, std:
143143 5. The order in which each token of this NPU is sent to various servers.
144144 size:[MAX_BS, serverNum]
145145 6. The order in which each token is sent to the expert.
146- size:[MAX_BS, numTopk ]
146+ size:[MAX_BS, numExpert ]
147147 7. The server offset of tokens received by each expert from this NPU.
148148 size:[numExpert, MAX_BS]
149149 */
@@ -157,6 +157,7 @@ Buffer::get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts, std:
157157 this ->notify_send_data = notify_send_data;
158158 this ->send_token_idx_small = send_token_idx_small;
159159 this ->notify_send_data_size = notify_send_data_size;
160+ this ->tokens_per_rank = num_tokens_per_rank;
160161
161162 std::optional<torch::Tensor> num_tokens_per_rdma_rank = std::nullopt ;
162163 std::optional<EventHandle> output_event = std::nullopt ;
@@ -770,6 +771,8 @@ Buffer::internode_dispatch(
770771 at::empty ({num_experts, num_ranks, MAX_BATCH_SIZE}, at::dtype (at::kInt ).device (x.device ()));
771772 at::Tensor dst_offset_rank_token_idx =
772773 at::empty ({num_experts, num_ranks, MAX_BATCH_SIZE}, at::dtype (at::kInt ).device (x.device ()));
774+ at::Tensor token_idx_per_expert =
775+ at::empty ({num_ranks, num_experts}, at::dtype (at::kInt ).device (x.device ()));
773776 // The offsetInner for the current rank and the peer rank
774777 at::Tensor offset_inner = at::empty ({2 , MAX_BATCH_SIZE, num_experts}, at::dtype (at::kInt ).device (x.device ()));
775778 at::Tensor count_outer = at::empty ({MAX_BATCH_SIZE}, at::dtype (at::kInt ).device (x.device ()));
@@ -792,7 +795,7 @@ Buffer::internode_dispatch(
792795 local_rank_size, local_rank_id,
793796 send_data_offset, // A2 not use
794797 recv_data, token_server_idx, token_unique_per_server, ep_rank_token_cnt, recv_tokens_per_expert,
795- src_offset_rank_token_idx, dst_offset_rank_token_idx, offset_inner, count_outer, expand_idx,
798+ src_offset_rank_token_idx, dst_offset_rank_token_idx, token_idx_per_expert, offset_inner, count_outer, expand_idx,
796799 total_recv_token);
797800
798801 int total_count = total_recv_token.item <int >();
@@ -808,7 +811,7 @@ Buffer::internode_dispatch(
808811 }
809812
810813 EXEC_NPU_CMD (aclnnDispatchNormalA2, new_x, expert_ids, x_scales, xActiveMask, new_topk_weights, token_server_idx,
811- token_unique_per_server, ep_rank_token_cnt, src_offset_rank_token_idx, dst_offset_rank_token_idx,
814+ token_unique_per_server, ep_rank_token_cnt, src_offset_rank_token_idx, dst_offset_rank_token_idx, token_idx_per_expert,
812815 hcom_ep_name, num_ranks, rank, num_experts, hcom_ep_name, tp_size, tp_rank, expertShardType,
813816 sharedExpertNum, sharedExpertRankNum, quant_mode, global_bs, expertTokenNumsType, expandx_out,
814817 dynamic_scales_out, expand_idx, expertTokenNums, epRecvCount, expand_scales,
0 commit comments