Skip to content

Commit a53f55d

Browse files
committed
First draft manuscript
rebuild tokenIdxPerExpert
1 parent cf0f73f commit a53f55d

14 files changed

+300
-37
lines changed

csrc/deepep/deep_ep.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ Buffer::get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts, std:
143143
5. The order in which each token of this NPU is sent to various servers.
144144
size:[MAX_BS, serverNum]
145145
6. The order in which each token is sent to the expert.
146-
size:[MAX_BS, numTopk]
146+
size:[MAX_BS, numExpert]
147147
7. The server offset of tokens received by each expert from this NPU.
148148
size:[numExpert, MAX_BS]
149149
*/
@@ -157,6 +157,7 @@ Buffer::get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts, std:
157157
this->notify_send_data = notify_send_data;
158158
this->send_token_idx_small = send_token_idx_small;
159159
this->notify_send_data_size = notify_send_data_size;
160+
this->tokens_per_rank = num_tokens_per_rank;
160161

161162
std::optional<torch::Tensor> num_tokens_per_rdma_rank = std::nullopt;
162163
std::optional<EventHandle> output_event = std::nullopt;
@@ -770,6 +771,8 @@ Buffer::internode_dispatch(
770771
at::empty({num_experts, num_ranks, MAX_BATCH_SIZE}, at::dtype(at::kInt).device(x.device()));
771772
at::Tensor dst_offset_rank_token_idx =
772773
at::empty({num_experts, num_ranks, MAX_BATCH_SIZE}, at::dtype(at::kInt).device(x.device()));
774+
at::Tensor token_idx_per_expert =
775+
at::empty({num_ranks, num_experts}, at::dtype(at::kInt).device(x.device()));
773776
// The offsetInner for the current rank and the peer rank
774777
at::Tensor offset_inner = at::empty({2, MAX_BATCH_SIZE, num_experts}, at::dtype(at::kInt).device(x.device()));
775778
at::Tensor count_outer = at::empty({MAX_BATCH_SIZE}, at::dtype(at::kInt).device(x.device()));
@@ -792,7 +795,7 @@ Buffer::internode_dispatch(
792795
local_rank_size, local_rank_id,
793796
send_data_offset, // A2 not use
794797
recv_data, token_server_idx, token_unique_per_server, ep_rank_token_cnt, recv_tokens_per_expert,
795-
src_offset_rank_token_idx, dst_offset_rank_token_idx, offset_inner, count_outer, expand_idx,
798+
src_offset_rank_token_idx, dst_offset_rank_token_idx, token_idx_per_expert, offset_inner, count_outer, expand_idx,
796799
total_recv_token);
797800

798801
int total_count = total_recv_token.item<int>();
@@ -808,7 +811,7 @@ Buffer::internode_dispatch(
808811
}
809812

810813
EXEC_NPU_CMD(aclnnDispatchNormalA2, new_x, expert_ids, x_scales, xActiveMask, new_topk_weights, token_server_idx,
811-
token_unique_per_server, ep_rank_token_cnt, src_offset_rank_token_idx, dst_offset_rank_token_idx,
814+
token_unique_per_server, ep_rank_token_cnt, src_offset_rank_token_idx, dst_offset_rank_token_idx, token_idx_per_expert,
812815
hcom_ep_name, num_ranks, rank, num_experts, hcom_ep_name, tp_size, tp_rank, expertShardType,
813816
sharedExpertNum, sharedExpertRankNum, quant_mode, global_bs, expertTokenNumsType, expandx_out,
814817
dynamic_scales_out, expand_idx, expertTokenNums, epRecvCount, expand_scales,

csrc/deepep/deep_ep.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ struct Buffer {
3434
at::Tensor new_scales;
3535
at::Tensor notify_send_data; // only for internode notify
3636
at::Tensor send_token_idx_small;
37+
at::Tensor tokens_per_rank;
3738
int notify_send_data_size; // only for internode notify
3839

3940
int64_t shared_expert_rank_num;

csrc/deepep/ops2/op_host/dispatch_normal_a2.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,12 @@ class DispatchNormalA2 : public OpDef
6666
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
6767
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
6868
.AutoContiguous();
69+
this->Input("tokenIdxPerExpert")
70+
.ParamType(OPTIONAL)
71+
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
72+
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
73+
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
74+
.AutoContiguous();
6975

7076
this->Output("recv_x")
7177
.ParamType(REQUIRED)

csrc/deepep/ops2/op_host/dispatch_normal_a2_tiling.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ constexpr uint32_t TOKEN_SERVER_CNT_INDEX = 6;
6363
constexpr uint32_t EP_RANK_TOKEN_CNT_INDEX = 7;
6464
constexpr uint32_t SRC_OFFSET_RANK_TOKEN_IDX_INDEX = 8;
6565
constexpr uint32_t DST_OFFSET_RANK_TOKEN_IDX_INDEX = 9;
66+
constexpr uint32_t TOKEN_IDX_PER_EXPERT_INDEX = 10;
6667
constexpr uint32_t OUTPUT_EXPAND_X_INDEX = 0;
6768
constexpr uint32_t OUTPUT_DYNAMIC_SCALES_INDEX = 1;
6869
constexpr uint32_t OUTPUT_EXPAND_IDX_INDEX = 2;
@@ -175,6 +176,7 @@ static bool CheckTensorDim(const gert::TilingContext &context, const char *nodeN
175176
return false);
176177
OP_LOGD(nodeName, "expertId dim0 = %ld", expertIdStorageShape->GetStorageShape().GetDim(0));
177178
OP_LOGD(nodeName, "expertId dim1 = %ld", expertIdStorageShape->GetStorageShape().GetDim(1));
179+
178180
// 如果scales不为空进行shape维度检查
179181
if (isScales) {
180182
const gert::StorageShape *scalesStorageShape = context.GetOptionalInputShape(SCALES_INDEX);
@@ -601,6 +603,7 @@ static ge::graphStatus CheckTensorShape(const gert::TilingContext &context, cons
601603
expertIdsDim1),
602604
return ge::GRAPH_FAILED);
603605
tilingData.moeDistributeDispatchInfo.k = static_cast<uint32_t>(expertIdsDim1);
606+
604607
// 校验scales的维度
605608
if (isScales) {
606609
const gert::StorageShape *scalesStorageShape = context.GetOptionalInputShape(SCALES_INDEX);
@@ -932,6 +935,10 @@ static ge::graphStatus MoeDistributeDispatchA2CheckShapeAndSetTiling(const gert:
932935
context.GetInputShape(DST_OFFSET_RANK_TOKEN_IDX_INDEX);
933936
OP_TILING_CHECK(dstOffsetRankTokenIdxStorageShape == nullptr,
934937
OP_LOGE(K_INNER_DEBUG, "dstOffsetRankTokenIdxStorageShape is null."), return GRAPH_FAILED);
938+
const gert::StorageShape *tokenIdxPerExpertStorageShape =
939+
context.GetInputShape(TOKEN_IDX_PER_EXPERT_INDEX);
940+
OP_TILING_CHECK(tokenIdxPerExpertStorageShape == nullptr,
941+
OP_LOGE(K_INNER_DEBUG, "tokenIdxPerExpertStorageShape is null."), return GRAPH_FAILED);
935942

936943
info.isQuant = isScales;
937944
info.bs = bs;

csrc/deepep/ops2/op_host/notify_dispatch_a2.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ class NotifyDispatchA2 : public OpDef
6161
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32})
6262
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
6363
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
64+
this->Output("tokenIdxPerExpert")
65+
.ParamType(REQUIRED)
66+
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32})
67+
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
6468
this->Output("offsetInner")
6569
.ParamType(REQUIRED)
6670
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32})

csrc/deepep/ops2/op_host/notify_dispatch_tiling_a2.cc

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,11 @@ constexpr uint32_t OUTPUT_EP_RANK_TOKEN_CNT_INDEX = 4;
6565
constexpr uint32_t OUTPUT_LOCAL_EP_TOKEN_CNT_INDEX = 5;
6666
constexpr uint32_t OUTPUT_SRC_OFFSET_RANK_TOKEN_INDEX = 6;
6767
constexpr uint32_t OUTPUT_DST_OFFSET_RANK_TOKEN_INDEX = 7;
68-
constexpr uint32_t OUTPUT_OFFSET_INNER_INDEX = 8;
69-
constexpr uint32_t OUTPUT_COUNT_OUTER_INDEX = 9;
70-
constexpr uint32_t OUTPUT_EXPAND_IDX_INDEX = 10;
71-
constexpr uint32_t OUTPUT_TOTAL_RECV_TOKENS_INDEX = 11;
68+
constexpr uint32_t TOKEN_IDX_PER_EXPERT_INDEX = 8;
69+
constexpr uint32_t OUTPUT_OFFSET_INNER_INDEX = 9;
70+
constexpr uint32_t OUTPUT_COUNT_OUTER_INDEX = 10;
71+
constexpr uint32_t OUTPUT_EXPAND_IDX_INDEX = 11;
72+
constexpr uint32_t OUTPUT_TOTAL_RECV_TOKENS_INDEX = 12;
7273

7374
constexpr uint32_t ATTR_SEND_COUNT_INDEX = 0;
7475
constexpr uint32_t ATTR_NUM_TOKENS_INDEX = 1;
@@ -327,6 +328,20 @@ static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeNa
327328
static_cast<ge::DataType>(dstOffsetRankTokenIdx->GetDataType())),
328329
return false);
329330

331+
auto tokenIdxPerExpert = context->GetOutputDesc(TOKEN_IDX_PER_EXPERT_INDEX);
332+
OP_TILING_CHECK(tokenIdxPerExpert == nullptr, OP_LOGE(nodeName, "tokenIdxPerExpert is null."),
333+
return false);
334+
OP_TILING_CHECK(
335+
(tokenIdxPerExpert->GetDataType() != ge::DT_BF16) &&
336+
(tokenIdxPerExpert->GetDataType() != ge::DT_FLOAT16) &&
337+
(tokenIdxPerExpert->GetDataType() != ge::DT_FLOAT) &&
338+
(tokenIdxPerExpert->GetDataType() != ge::DT_INT32),
339+
OP_LOGE(
340+
nodeName,
341+
"tokenIdxPerExpert datatype is invalid, datatype should be bf16 or float16 or float or int, but is %d.",
342+
static_cast<ge::DataType>(tokenIdxPerExpert->GetDataType())),
343+
return false);
344+
330345
auto offsetInner = context->GetOutputDesc(OUTPUT_OFFSET_INNER_INDEX);
331346
OP_TILING_CHECK(offsetInner == nullptr, OP_LOGE(nodeName, "offsetInner is null."), return false);
332347
OP_TILING_CHECK(

csrc/deepep/ops2/op_host/op_api/aclnn_dispatch_normal_a2.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ extern "C" {
1919
aclnnStatus aclnnDispatchNormalA2GetWorkspaceSize(
2020
const aclTensor *x, const aclTensor *expertIds, const aclTensor *scales, const aclTensor *xActiveMask,
2121
const aclTensor *expertScales, const aclTensor *tokenServerIdx, const aclTensor *tokenServerCnt,
22-
const aclTensor *epRankTokenCnt, const aclTensor *srcOffsetRankTokenIdx, const aclTensor *dstOffsetRankTokenIdx,
22+
const aclTensor *epRankTokenCnt, const aclTensor *srcOffsetRankTokenIdx, const aclTensor *dstOffsetRankTokenIdx, const aclTensor *tokenIdxPerExpert,
2323
char *groupEp, int64_t epWorldSize, int64_t epRankId, int64_t moeExpertNum, char *groupTp, int64_t tpWorldSize,
2424
int64_t tpRankId, int64_t expertShardType, int64_t sharedExpertNum, int64_t sharedExpertRankNum, int64_t quantMode,
2525
int64_t globalBs, int64_t expertTokenNumsType, const aclTensor *recvX, const aclTensor *dynamicScales,
@@ -29,7 +29,7 @@ aclnnStatus aclnnDispatchNormalA2GetWorkspaceSize(
2929
{
3030
return aclnnInnerDispatchNormalA2GetWorkspaceSize(
3131
x, expertIds, scales, xActiveMask, expertScales, tokenServerIdx, tokenServerCnt, epRankTokenCnt,
32-
srcOffsetRankTokenIdx, dstOffsetRankTokenIdx, groupEp, epWorldSize, epRankId, moeExpertNum, groupTp,
32+
srcOffsetRankTokenIdx, dstOffsetRankTokenIdx,tokenIdxPerExpert, groupEp, epWorldSize, epRankId, moeExpertNum, groupTp,
3333
tpWorldSize, tpRankId, expertShardType, sharedExpertNum, sharedExpertRankNum, quantMode, globalBs,
3434
expertTokenNumsType, recvX, dynamicScales, expandIdx, expertTokenNums, epRecvCount, expandScales,
3535
waitRecvCostStats, workspaceSize, executor);

csrc/deepep/ops2/op_host/op_api/aclnn_dispatch_normal_a2.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ extern "C" {
1010
__attribute__((visibility("default"))) aclnnStatus aclnnDispatchNormalA2GetWorkspaceSize(
1111
const aclTensor *x, const aclTensor *expertIds, const aclTensor *scales, const aclTensor *xActiveMask,
1212
const aclTensor *expertScales, const aclTensor *tokenServerIdx, const aclTensor *tokenServerCnt,
13-
const aclTensor *epRankTokenCnt, const aclTensor *srcOffsetRankTokenIdx, const aclTensor *dstOffsetRankTokenIdx,
13+
const aclTensor *epRankTokenCnt, const aclTensor *srcOffsetRankTokenIdx, const aclTensor *dstOffsetRankTokenIdx,const aclTensor *tokenIdxPerExpert,
1414
char *groupEp, int64_t epWorldSize, int64_t epRankId, int64_t moeExpertNum, char *groupTp, int64_t tpWorldSize,
1515
int64_t tpRankId, int64_t expertShardType, int64_t sharedExpertNum, int64_t sharedExpertRankNum, int64_t quantMode,
1616
int64_t globalBs, int64_t expertTokenNumsType, const aclTensor *recvX, const aclTensor *dynamicScales,

csrc/deepep/ops2/op_host/op_api/aclnn_notify_dispatch_a2.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@ aclnnStatus aclnnNotifyDispatchA2GetWorkspaceSize(
2020
int64_t numTokens, int64_t topkNum, int64_t numExperts, char *commGroup, int64_t rankSize, int64_t rankId,
2121
int64_t localRankSize, int64_t localRankId, const aclTensor *sendDataOffset, const aclTensor *recvData,
2222
const aclTensor *tokenServerIdx, const aclTensor *tokenUniquePerServer, const aclTensor *epRankTokenCnt,
23-
const aclTensor *localEpTokenCnt, const aclTensor *srcOffsetRankTokenIdx, const aclTensor *dstOffsetRankTokenIdx,
23+
const aclTensor *localEpTokenCnt, const aclTensor *srcOffsetRankTokenIdx, const aclTensor *dstOffsetRankTokenIdx,const aclTensor *tokenIdxPerExpert,
2424
const aclTensor *offsetInner, const aclTensor *countOuter, const aclTensor *expandIdx,
2525
const aclTensor *totalRecvTokens, uint64_t *workspaceSize, aclOpExecutor **executor)
2626
{
2727
return aclnnInnerNotifyDispatchA2GetWorkspaceSize(
2828
sendData, tokenPerExpertData, tmpData, sendCount, numTokens, topkNum, numExperts, commGroup, rankSize, rankId,
2929
localRankSize, localRankId, sendDataOffset, recvData, tokenServerIdx, tokenUniquePerServer, epRankTokenCnt,
30-
localEpTokenCnt, srcOffsetRankTokenIdx, dstOffsetRankTokenIdx, offsetInner, countOuter, expandIdx,
30+
localEpTokenCnt, srcOffsetRankTokenIdx, dstOffsetRankTokenIdx, tokenIdxPerExpert, offsetInner, countOuter, expandIdx,
3131
totalRecvTokens, workspaceSize, executor);
3232
}
3333

csrc/deepep/ops2/op_host/op_api/aclnn_notify_dispatch_a2.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ extern "C" {
2929
* localEpTokenCnt : required
3030
* srcOffsetRankTokenIdx : required
3131
* dstOffsetRankTokenIdx : required
32+
* tokenIdxPerExpert : required
3233
* offsetInner : required
3334
* countOuter : required
3435
* expandIdx : required
@@ -40,7 +41,7 @@ __attribute__((visibility("default"))) aclnnStatus aclnnNotifyDispatchA2GetWorks
4041
int64_t numTokens, int64_t topkNum, int64_t numExperts, char *commGroup, int64_t rankSize, int64_t rankId,
4142
int64_t localRankSize, int64_t localRankId, const aclTensor *sendDataOffset, const aclTensor *recvData,
4243
const aclTensor *tokenServerIdx, const aclTensor *tokenUniquePerServer, const aclTensor *epRankTokenCnt,
43-
const aclTensor *localEpTokenCnt, const aclTensor *srcOffsetRankTokenIdx, const aclTensor *dstOffsetRankTokenIdx,
44+
const aclTensor *localEpTokenCnt, const aclTensor *srcOffsetRankTokenIdx, const aclTensor *dstOffsetRankTokenIdx,const aclTensor *tokenIdxPerExpert,
4445
const aclTensor *offsetInner, const aclTensor *countOuter, const aclTensor *expandIdx,
4546
const aclTensor *totalRecvTokens, uint64_t *workspaceSize, aclOpExecutor **executor);
4647

0 commit comments

Comments
 (0)