Skip to content

Commit 4059c69

Browse files
committed
compile
1 parent a53f55d commit 4059c69

File tree

1 file changed

+40
-33
lines changed

1 file changed

+40
-33
lines changed

csrc/deepep/ops2/op_kernel/a2/moe_distribute_dispatch_a2_pipeline.h

Lines changed: 40 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ namespace MoeDistributeDispatchA2Impl {
1111
constexpr uint32_t STATE_OFFSET = 512; // 状态空间偏移地址
1212
constexpr uint32_t STATUS_SIZE_LAYERED = 1024 * 1024; // 1M
1313
constexpr uint32_t HCCS_RING_BUFFER_HEAD_TAIL = 8 * 2 * 32;
14-
constexpr uint32_t EACH_HCCS_RING_BUFFER_HEAD_TAIL = 8 * 2 * 32;
14+
constexpr uint32_t EACH_HCCS_RING_BUFFER_HEAD_TAIL = 2 * 32;
1515
constexpr uint32_t RING_BUFFER_HEAD_TAIL = 8 * 32;
1616
constexpr uint32_t RDMA_BUFFER_ALIGN = 4 * 1024;
1717
constexpr uint32_t SELF_STATE_OFFSET = 512 * 1024; // 本卡状态空间偏移地址
@@ -129,9 +129,9 @@ class MoeDistributeDispatchA2Pipeline
129129

130130
LocalTensor<int32_t> tokenServerIdxTensor_;
131131
LocalTensor<int32_t> serverCountTensor_;
132-
LocalTensor<uint32_t> tokenStructInRdmaTensor_;
133-
LocalTensor<uint32_t> tokenStructInHccsTensor_;
134-
LocalTensor<uint32_t> rdmaUseTokenStructInHccsTensor_;
132+
LocalTensor<uint8_t> tokenStructInRdmaTensor_;
133+
LocalTensor<uint8_t> tokenStructInHccsTensor_;
134+
LocalTensor<uint8_t> rdmaUseTokenStructInHccsTensor_;
135135

136136
TBuf<> tokenServerIdxBuf_;
137137
TBuf<> serverCountBuf_;
@@ -159,6 +159,7 @@ class MoeDistributeDispatchA2Pipeline
159159
GM_ADDR expertToServerCntGM_;
160160
GM_ADDR shareAddrs[8];
161161
GM_ADDR shareAddrWins[8];
162+
GM_ADDR hccsHeadTailGM[8];
162163

163164
// tiling侧已确保数据上限,相乘不会越界,因此统一采用uint32_t进行处理
164165
uint32_t axisBS_{0};
@@ -316,12 +317,12 @@ __aicore__ inline void MoeDistributeDispatchA2Pipeline<TemplateMC2TypeA2Pipeline
316317
sendStatusTensor_.SetGlobalBuffer((__gm__ int32_t *)(windowOutGM_ + WIN_SIZE));
317318
readStatusTensor_.SetGlobalBuffer((__gm__ int32_t *)(windowInGM_ + WIN_SIZE));
318319
for (int i = 0; i < SERVER_RANK_SIZE; i++) {
319-
hccsHeadTailTensor_[i].SetGlobalBuffer((__gm__ int32_t *)(hccl_.GetWindowsInAddr(rankId_ / SERVER_RANK_SIZE * SERVER_RANK_SIZE + i) + halfWinSize_ -
320-
EACH_HCCS_RING_BUFFER_HEAD_TAIL * i));
320+
hccsHeadTailGM[i] = (__gm__ uint8_t *)(reinterpret_cast<uint64_t>(hccl_.GetWindowsInAddr(rankId_ / SERVER_RANK_SIZE * SERVER_RANK_SIZE + i) + halfWinSize_ -
321+
EACH_HCCS_RING_BUFFER_HEAD_TAIL));
321322
}
322-
hccsHeadTailTensor_.SetGlobalBuffer((__gm__ int32_t *)(windowInGM_ + halfWinSize_ -
323-
HCCS_RING_BUFFER_HEAD_TAIL));
324-
rdmaHeadTailTensor_.SetGlobalBuffer((__gm__ int32_t *)(windowInGM_ + halfWinSize_ - HCCS_RING_BUFFER_HEAD_TAIL -
323+
// hccsHeadTailTensor_.SetGlobalBuffer((__gm__ int32_t *)(windowInGM_ + halfWinSize_ -
324+
// HCCS_RING_BUFFER_HEAD_TAIL));
325+
rdmaHeadTailTensor_.SetGlobalBuffer((__gm__ uint32_t *)(windowInGM_ + halfWinSize_ - HCCS_RING_BUFFER_HEAD_TAIL -
325326
RING_BUFFER_HEAD_TAIL * serverNum));
326327

327328
expertTokenNumsOutGM_ = expertTokenNumsOut; // 无GlobalTensor
@@ -357,13 +358,13 @@ __aicore__ inline void MoeDistributeDispatchA2Pipeline<TemplateMC2TypeA2Pipeline
357358
expertToServerIdxTensor_ = expertToServerIdxBuf_.Get<uint32_t>();
358359

359360
tpipe_->InitBuffer(tokenStructInRdmaBuf_, tokenLenInStruct_);
360-
tokenStructInRdmaTensor_ = tokenStructInRdmaBuf_.Get<uint32_t>();
361+
tokenStructInRdmaTensor_ = tokenStructInRdmaBuf_.Get<uint8_t>();
361362

362363
tpipe_->InitBuffer(tokenStructInHccsBuf_, tokenLenInStruct_);
363-
tokenStructInHccsTensor_ = tokenStructInHccsBuf_.Get<uint32_t>();
364+
tokenStructInHccsTensor_ = tokenStructInHccsBuf_.Get<uint8_t>();
364365

365366
tpipe_->InitBuffer(rdmaUseTokenStructInHccsBuf_, tokenLenInStruct_);
366-
rdmaUseTokenStructInHccsTensor_ = rdmaUseTokenStructInHccsBuf_.Get<uint32_t>();
367+
rdmaUseTokenStructInHccsTensor_ = rdmaUseTokenStructInHccsBuf_.Get<uint8_t>();
367368

368369
tpipe_->InitBuffer(expertCountBuf_, moeExpertNum_ * sizeof(int32_t)); // moeNum * 4
369370
expertCountTensor_ = expertCountBuf_.Get<int32_t>();
@@ -466,7 +467,8 @@ __aicore__ inline void MoeDistributeDispatchA2Pipeline<TemplateMC2TypeA2Pipeline
466467
}
467468
taskEnd = taskStart + taskNumPerCore;
468469
DataCopyExtParams tokenStructParams{1, static_cast<uint32_t>(tokenStructLen_), 0, 0, 0};
469-
DataCopyPadExtParams<uint32_t> tokenStructPadParams{false, 0U, 0U, 0U};
470+
DataCopyPadExtParams<uint8_t> tokenStructPadParams{false, 0U, 0U, 0U};
471+
DataCopyParams hccsHesdTailParams{2, sizeof(uint32_t), 0, 0};
470472
uint32_t processedTokenNum = 0;
471473
uint32_t tokenGlobalCnt = 0;
472474
for (int i = taskStart; i < taskEnd; i++) {
@@ -507,16 +509,20 @@ __aicore__ inline void MoeDistributeDispatchA2Pipeline<TemplateMC2TypeA2Pipeline
507509
continue;
508510
}
509511
uint32_t localDstRank = (dstExpert - expertIdxStart) / localMoeExpertNum_;
510-
GlobalTensor<uint8_t> localDstRankRecvRingU8Tensor;
511-
localDstRankRecvRingU8Tensor.SetGlobalBuffer((__gm__ uint8_t *) (hccl_.GetWindowsInAddr(rankId_)) + halfWinSize_ / 2);
512-
uint32_t hcclTail = hccsHeadTailTensor_.GetValue(localDstRank * 2 + 1); //localDstRank * 2为第localDstRank个rank的hccl头尾,0为hccl头,1为hccl尾
513-
uint32_t hcclHead = hccsHeadTailTensor_.GetValue(localDstRank * 2);
512+
GlobalTensor<uint8_t> dstRankRecvRingU8Tensor;
513+
dstRankRecvRingU8Tensor.SetGlobalBuffer((__gm__ uint8_t *) (hccl_.GetWindowsInAddr(rankId_)) + halfWinSize_ / 2);
514+
LocalTensor<uint32_t> localHccsHeadTailTensor;
515+
GlobalTensor<uint32_t> globalHccsHeadTailTensor;
516+
globalHccsHeadTailTensor.SetGlobalBuffer((__gm__ uint32_t *)hccsHeadTailGM[rankId_ / SERVER_RANK_SIZE * SERVER_RANK_SIZE + i]);
517+
DataCopy(localHccsHeadTailTensor, globalHccsHeadTailTensor[localDstRank], hccsHesdTailParams);
518+
uint32_t hcclTail = localHccsHeadTailTensor.GetValue(1);
519+
uint32_t hcclHead = localHccsHeadTailTensor.GetValue(0);
514520
uint32_t index = 0;
515-
while (hcclHead == (hcclTail + 1) % hccsItemNum && !Ascend::AtomicCas(address + localDstRank, 0, 1)) {//谁抢到锁谁出循环
516-
hcclHead = hccsHeadTailTensor_.GetValue(localDstRank * 2); //优化点,当前处理完一整个token后再进行下一个token的处理,此处可以有优化空间,尝试跳过无空闲的hccs环形缓冲区
521+
while (hcclHead == (hcclTail + 1) % hccsItemNum) {//谁抢到锁谁出循环 && !Ascend::AtomicCas(address + localDstRank, 0, 1)
522+
hcclHead = localHccsHeadTailTensor.GetValue(0); //优化点,当前处理完一整个token后再进行下一个token的处理,此处可以有优化空间,尝试跳过无空闲的hccs环形缓冲区
517523
}
518524
for (int k = 0; k < hccsItemNum; k++) {
519-
DataCopyPad(rdmaUseTokenStructInHccsTensor_, localDstRankRecvRingU8Tensor[k * tokenStructLen_],
525+
DataCopyPad(rdmaUseTokenStructInHccsTensor_, dstRankRecvRingU8Tensor[k * tokenStructLen_],
520526
tokenStructParams, tokenStructPadParams);
521527
LocalTensor<int> tokenIdTensor = rdmaUseTokenStructInHccsTensor_[cntOffsetInStruct_].ReinterpretCast<int>();
522528
int tokenId = tokenIdTensor.GetValue(0);
@@ -526,15 +532,16 @@ __aicore__ inline void MoeDistributeDispatchA2Pipeline<TemplateMC2TypeA2Pipeline
526532
}
527533
}
528534
SyncFunc<AscendC::HardEvent::S_MTE3>();
529-
DataCopyPad(localDstRankRecvRingU8Tensor[tokenStructLen_ * index], tokenStructInRdmaTensor_,
530-
tokenStructParams, tokenStructPadParams);
535+
DataCopyPad(dstRankRecvRingU8Tensor[tokenStructLen_ * index], tokenStructInRdmaTensor_,
536+
tokenStructParams);
531537
tokenIdxInStructTensor.SetValue(0, -1);
532538
DataCopyPad(rdmaRecvRingU8Tensor_[(i * rdmaItemNum + rdmaHead) * tokenStructLen_], tokenStructInRdmaTensor_,
533-
tokenStructParams, tokenStructPadParams);
539+
tokenStructParams);
534540
rdmaHead = (rdmaHead + 1) % rdmaItemNum;
535541
hcclTail = (hcclTail + 1) % hccsItemNum;
536542
rdmaHeadTailTensor_.SetValue(i * RING_BUFFER_HEAD_TAIL + 2, rdmaHead);
537-
hccsHeadTailTensor_[localDstRank * 2].SetValue(1, hcclTail);
543+
localHccsHeadTailTensor.SetValue(1, hcclTail);
544+
DataCopy(globalHccsHeadTailTensor[localDstRank], localHccsHeadTailTensor, hccsHesdTailParams);
538545
}
539546
processedTokenNum++;
540547
}
@@ -584,7 +591,7 @@ __aicore__ inline void MoeDistributeDispatchA2Pipeline<TemplateMC2TypeA2Pipeline
584591
uint32_t processedTokens = 0;
585592
DataCopyExtParams tokenStructParams{1, static_cast<uint32_t>(tokenStructLen_), 0, 0, 0};
586593
DataCopyExtParams tokenParams{1, static_cast<uint32_t>(tokenLenInStruct_), 0, 0, 0};
587-
DataCopyPadExtParams<uint32_t> tokenStructPadParams{false, 0U, 0U, 0U};
594+
DataCopyPadExtParams<uint8_t> tokenStructPadParams{false, 0U, 0U, 0U};
588595
DataCopyExtParams weightParams{1, static_cast<uint32_t>(sizeof(float)), 0, 0, 0};
589596
DataCopyPadExtParams<uint32_t> weightExtParams{false, 0U, 0U, 0U};
590597
DataCopyExtParams scalesParams{1, static_cast<uint32_t>(sizeof(float)), 0, 0, 0};
@@ -601,14 +608,15 @@ __aicore__ inline void MoeDistributeDispatchA2Pipeline<TemplateMC2TypeA2Pipeline
601608
tokenStructParams, tokenStructPadParams);
602609
uint32_t expertIdxStart = localMoeExpertNum_ * rankId_;
603610
uint32_t expertIdxEnd = expertIdxStart + localMoeExpertNum_;
604-
LocalTensor<int> tokenIdxInStructTensor = tokenStructInHccsTensor_[cntOffsetInStruct_].Reinterpret<int>();
611+
LocalTensor<int> tokenIdxInStructTensor = tokenStructInHccsTensor_[cntOffsetInStruct_].ReinterpretCast<int>();
612+
LocalTensor<uint8_t> tokenIdxInStructToGmTensor = tokenStructInHccsTensor_[cntOffsetInStruct_];
605613
uint32_t tokenIdx = tokenIdxInStructTensor.GetValue(0);
606614
if (tokenIdx < 0) {
607615
continue;
608616
}
609-
LocalTensor<float> weightTensor = tokenStructInHccsTensor_[weightOffsetInStruct_].Reinterpret<float>();
610-
LocalTensor<ExpandXOutType> tokenOutTensor = tokenStructInHccsTensor_.Reinterpret<ExpandXOutType>();
611-
LocalTensor<int> topkIdxTensor = tokenStructInHccsTensor_[expOffsetInStruct_].Reinterpret<int>();
617+
LocalTensor<float> weightTensor = tokenStructInHccsTensor_[weightOffsetInStruct_].ReinterpretCast<float>();
618+
LocalTensor<ExpandXOutType> tokenOutTensor = tokenStructInHccsTensor_.ReinterpretCast<ExpandXOutType>();
619+
LocalTensor<int> topkIdxTensor = tokenStructInHccsTensor_[expOffsetInStruct_].ReinterpretCast<int>();
612620
uint32_t dstOffset = 0;
613621
for (int j = 0; j < axisK_; j++) {
614622
SyncFunc<AscendC::HardEvent::MTE3_S>();
@@ -622,15 +630,14 @@ __aicore__ inline void MoeDistributeDispatchA2Pipeline<TemplateMC2TypeA2Pipeline
622630
DataCopyPad(expandXOutGMTensor_[dstOffset], tokenOutTensor, tokenParams);
623631
// dynamic scales to output
624632
if constexpr (DynamicQuant) {
625-
LocalTensor<float> quantTempUB = localUB[scaleOffsetInStruct_].ReinterpretCast<float>();
633+
LocalTensor<float> quantTempUB = tokenStructInHccsTensor_[scaleOffsetInStruct_].ReinterpretCast<float>();
626634
DataCopyPad(dynamicScalesOutGMTensor_[dstOffset], quantTempUB, scalesParams);
627635
}
628636
}
629637
tokenIdxInStructTensor.SetValue(0, -1);
630-
DataCopyPad(hccsRecvRingU8Tensor_[tokenStructLen_ * i], tokenIdxInStructTensor, tokenStructParams,
631-
tokenStructPadParams);
638+
DataCopyPad(hccsRecvRingU8Tensor_[tokenStructLen_ * i], tokenIdxInStructToGmTensor, tokenStructParams);
632639
uint32_t hcclHead = hccsHeadTailTensor_.GetValue(localRankId * 2); //需要一个锁,避免多个core同时更新本rank的head
633-
hccsTailTensor_.SeteValue(localRankId, hcclHead + 1);
640+
hccsHeadTailTensor_.SetValue(localRankId, hcclHead + 1);
634641
++processedTokens;
635642
}
636643
}

0 commit comments

Comments
 (0)