@@ -11,7 +11,7 @@ namespace MoeDistributeDispatchA2Impl {
1111constexpr uint32_t STATE_OFFSET = 512 ; // 状态空间偏移地址
1212constexpr uint32_t STATUS_SIZE_LAYERED = 1024 * 1024 ; // 1M
1313constexpr uint32_t HCCS_RING_BUFFER_HEAD_TAIL = 8 * 2 * 32 ;
14- constexpr uint32_t EACH_HCCS_RING_BUFFER_HEAD_TAIL = 8 * 2 * 32 ;
14+ constexpr uint32_t EACH_HCCS_RING_BUFFER_HEAD_TAIL = 2 * 32 ;
1515constexpr uint32_t RING_BUFFER_HEAD_TAIL = 8 * 32 ;
1616constexpr uint32_t RDMA_BUFFER_ALIGN = 4 * 1024 ;
1717constexpr uint32_t SELF_STATE_OFFSET = 512 * 1024 ; // 本卡状态空间偏移地址
@@ -129,9 +129,9 @@ class MoeDistributeDispatchA2Pipeline
129129
130130 LocalTensor<int32_t > tokenServerIdxTensor_;
131131 LocalTensor<int32_t > serverCountTensor_;
132- LocalTensor<uint32_t > tokenStructInRdmaTensor_;
133- LocalTensor<uint32_t > tokenStructInHccsTensor_;
134- LocalTensor<uint32_t > rdmaUseTokenStructInHccsTensor_;
132+ LocalTensor<uint8_t > tokenStructInRdmaTensor_;
133+ LocalTensor<uint8_t > tokenStructInHccsTensor_;
134+ LocalTensor<uint8_t > rdmaUseTokenStructInHccsTensor_;
135135
136136 TBuf<> tokenServerIdxBuf_;
137137 TBuf<> serverCountBuf_;
@@ -159,6 +159,7 @@ class MoeDistributeDispatchA2Pipeline
159159 GM_ADDR expertToServerCntGM_;
160160 GM_ADDR shareAddrs[8 ];
161161 GM_ADDR shareAddrWins[8 ];
162+ GM_ADDR hccsHeadTailGM[8 ];
162163
163164 // tiling侧已确保数据上限,相乘不会越界,因此统一采用uint32_t进行处理
164165 uint32_t axisBS_{0 };
@@ -316,12 +317,12 @@ __aicore__ inline void MoeDistributeDispatchA2Pipeline<TemplateMC2TypeA2Pipeline
316317 sendStatusTensor_.SetGlobalBuffer ((__gm__ int32_t *)(windowOutGM_ + WIN_SIZE));
317318 readStatusTensor_.SetGlobalBuffer ((__gm__ int32_t *)(windowInGM_ + WIN_SIZE));
318319 for (int i = 0 ; i < SERVER_RANK_SIZE; i++) {
319- hccsHeadTailTensor_ [i]. SetGlobalBuffer (( __gm__ int32_t *)(hccl_.GetWindowsInAddr (rankId_ / SERVER_RANK_SIZE * SERVER_RANK_SIZE + i) + halfWinSize_ -
320- EACH_HCCS_RING_BUFFER_HEAD_TAIL * i ));
320+ hccsHeadTailGM [i] = ( __gm__ uint8_t *)( reinterpret_cast < uint64_t > (hccl_.GetWindowsInAddr (rankId_ / SERVER_RANK_SIZE * SERVER_RANK_SIZE + i) + halfWinSize_ -
321+ EACH_HCCS_RING_BUFFER_HEAD_TAIL));
321322 }
322- hccsHeadTailTensor_.SetGlobalBuffer ((__gm__ int32_t *)(windowInGM_ + halfWinSize_ -
323- HCCS_RING_BUFFER_HEAD_TAIL));
324- rdmaHeadTailTensor_.SetGlobalBuffer ((__gm__ int32_t *)(windowInGM_ + halfWinSize_ - HCCS_RING_BUFFER_HEAD_TAIL -
323+ // hccsHeadTailTensor_.SetGlobalBuffer((__gm__ int32_t *)(windowInGM_ + halfWinSize_ -
324+ // HCCS_RING_BUFFER_HEAD_TAIL));
325+ rdmaHeadTailTensor_.SetGlobalBuffer ((__gm__ uint32_t *)(windowInGM_ + halfWinSize_ - HCCS_RING_BUFFER_HEAD_TAIL -
325326 RING_BUFFER_HEAD_TAIL * serverNum));
326327
327328 expertTokenNumsOutGM_ = expertTokenNumsOut; // 无GlobalTensor
@@ -357,13 +358,13 @@ __aicore__ inline void MoeDistributeDispatchA2Pipeline<TemplateMC2TypeA2Pipeline
357358 expertToServerIdxTensor_ = expertToServerIdxBuf_.Get <uint32_t >();
358359
359360 tpipe_->InitBuffer (tokenStructInRdmaBuf_, tokenLenInStruct_);
360- tokenStructInRdmaTensor_ = tokenStructInRdmaBuf_.Get <uint32_t >();
361+ tokenStructInRdmaTensor_ = tokenStructInRdmaBuf_.Get <uint8_t >();
361362
362363 tpipe_->InitBuffer (tokenStructInHccsBuf_, tokenLenInStruct_);
363- tokenStructInHccsTensor_ = tokenStructInHccsBuf_.Get <uint32_t >();
364+ tokenStructInHccsTensor_ = tokenStructInHccsBuf_.Get <uint8_t >();
364365
365366 tpipe_->InitBuffer (rdmaUseTokenStructInHccsBuf_, tokenLenInStruct_);
366- rdmaUseTokenStructInHccsTensor_ = rdmaUseTokenStructInHccsBuf_.Get <uint32_t >();
367+ rdmaUseTokenStructInHccsTensor_ = rdmaUseTokenStructInHccsBuf_.Get <uint8_t >();
367368
368369 tpipe_->InitBuffer (expertCountBuf_, moeExpertNum_ * sizeof (int32_t )); // moeNum * 4
369370 expertCountTensor_ = expertCountBuf_.Get <int32_t >();
@@ -466,7 +467,8 @@ __aicore__ inline void MoeDistributeDispatchA2Pipeline<TemplateMC2TypeA2Pipeline
466467 }
467468 taskEnd = taskStart + taskNumPerCore;
468469 DataCopyExtParams tokenStructParams{1 , static_cast <uint32_t >(tokenStructLen_), 0 , 0 , 0 };
469- DataCopyPadExtParams<uint32_t > tokenStructPadParams{false , 0U , 0U , 0U };
470+ DataCopyPadExtParams<uint8_t > tokenStructPadParams{false , 0U , 0U , 0U };
471+ DataCopyParams hccsHesdTailParams{2 , sizeof (uint32_t ), 0 , 0 };
470472 uint32_t processedTokenNum = 0 ;
471473 uint32_t tokenGlobalCnt = 0 ;
472474 for (int i = taskStart; i < taskEnd; i++) {
@@ -507,16 +509,20 @@ __aicore__ inline void MoeDistributeDispatchA2Pipeline<TemplateMC2TypeA2Pipeline
507509 continue ;
508510 }
509511 uint32_t localDstRank = (dstExpert - expertIdxStart) / localMoeExpertNum_;
510- GlobalTensor<uint8_t > localDstRankRecvRingU8Tensor;
511- localDstRankRecvRingU8Tensor.SetGlobalBuffer ((__gm__ uint8_t *) (hccl_.GetWindowsInAddr (rankId_)) + halfWinSize_ / 2 );
512- uint32_t hcclTail = hccsHeadTailTensor_.GetValue (localDstRank * 2 + 1 ); // localDstRank * 2为第localDstRank个rank的hccl头尾,0为hccl头,1为hccl尾
513- uint32_t hcclHead = hccsHeadTailTensor_.GetValue (localDstRank * 2 );
512+ GlobalTensor<uint8_t > dstRankRecvRingU8Tensor;
513+ dstRankRecvRingU8Tensor.SetGlobalBuffer ((__gm__ uint8_t *) (hccl_.GetWindowsInAddr (rankId_)) + halfWinSize_ / 2 );
514+ LocalTensor<uint32_t > localHccsHeadTailTensor;
515+ GlobalTensor<uint32_t > globalHccsHeadTailTensor;
516+ globalHccsHeadTailTensor.SetGlobalBuffer ((__gm__ uint32_t *)hccsHeadTailGM[rankId_ / SERVER_RANK_SIZE * SERVER_RANK_SIZE + i]);
517+ DataCopy (localHccsHeadTailTensor, globalHccsHeadTailTensor[localDstRank], hccsHesdTailParams);
518+ uint32_t hcclTail = localHccsHeadTailTensor.GetValue (1 );
519+ uint32_t hcclHead = localHccsHeadTailTensor.GetValue (0 );
514520 uint32_t index = 0 ;
515- while (hcclHead == (hcclTail + 1 ) % hccsItemNum && !Ascend::AtomicCas (address + localDstRank, 0 , 1 )) { // 谁抢到锁谁出循环
516- hcclHead = hccsHeadTailTensor_ .GetValue (localDstRank * 2 ); // 优化点,当前处理完一整个token后再进行下一个token的处理,此处可以有优化空间,尝试跳过无空闲的hccs环形缓冲区
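521+ while (hcclHead == (hcclTail + 1 ) % hccsItemNum) { // spin while the HCCS ring buffer is full; the lock term "&& !Ascend::AtomicCas(address + localDstRank, 0, 1)" (whoever grabs the lock leaves the loop) is disabled here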
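522+ hcclHead = localHccsHeadTailTensor .GetValue (0 ); // Optimization opportunity: one token is currently processed to completion before the next one starts; consider skipping HCCS ring buffers that have no free slot. NOTE(review): this re-reads the local tensor, and nothing inside this loop copies the head from GM again - confirm the wait can terminate.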
517523 }
518524 for (int k = 0 ; k < hccsItemNum; k++) {
519- DataCopyPad (rdmaUseTokenStructInHccsTensor_, localDstRankRecvRingU8Tensor [k * tokenStructLen_],
525+ DataCopyPad (rdmaUseTokenStructInHccsTensor_, dstRankRecvRingU8Tensor [k * tokenStructLen_],
520526 tokenStructParams, tokenStructPadParams);
521527 LocalTensor<int > tokenIdTensor = rdmaUseTokenStructInHccsTensor_[cntOffsetInStruct_].ReinterpretCast <int >();
522528 int tokenId = tokenIdTensor.GetValue (0 );
@@ -526,15 +532,16 @@ __aicore__ inline void MoeDistributeDispatchA2Pipeline<TemplateMC2TypeA2Pipeline
526532 }
527533 }
528534 SyncFunc<AscendC::HardEvent::S_MTE3>();
529- DataCopyPad (localDstRankRecvRingU8Tensor [tokenStructLen_ * index], tokenStructInRdmaTensor_,
530- tokenStructParams, tokenStructPadParams );
535+ DataCopyPad (dstRankRecvRingU8Tensor [tokenStructLen_ * index], tokenStructInRdmaTensor_,
536+ tokenStructParams);
531537 tokenIdxInStructTensor.SetValue (0 , -1 );
532538 DataCopyPad (rdmaRecvRingU8Tensor_[(i * rdmaItemNum + rdmaHead) * tokenStructLen_], tokenStructInRdmaTensor_,
533- tokenStructParams, tokenStructPadParams );
539+ tokenStructParams);
534540 rdmaHead = (rdmaHead + 1 ) % rdmaItemNum;
535541 hcclTail = (hcclTail + 1 ) % hccsItemNum;
536542 rdmaHeadTailTensor_.SetValue (i * RING_BUFFER_HEAD_TAIL + 2 , rdmaHead);
537- hccsHeadTailTensor_[localDstRank * 2 ].SetValue (1 , hcclTail);
543+ localHccsHeadTailTensor.SetValue (1 , hcclTail);
544+ DataCopy (globalHccsHeadTailTensor[localDstRank], localHccsHeadTailTensor, hccsHesdTailParams);
538545 }
539546 processedTokenNum++;
540547 }
@@ -584,7 +591,7 @@ __aicore__ inline void MoeDistributeDispatchA2Pipeline<TemplateMC2TypeA2Pipeline
584591 uint32_t processedTokens = 0 ;
585592 DataCopyExtParams tokenStructParams{1 , static_cast <uint32_t >(tokenStructLen_), 0 , 0 , 0 };
586593 DataCopyExtParams tokenParams{1 , static_cast <uint32_t >(tokenLenInStruct_), 0 , 0 , 0 };
587- DataCopyPadExtParams<uint32_t > tokenStructPadParams{false , 0U , 0U , 0U };
594+ DataCopyPadExtParams<uint8_t > tokenStructPadParams{false , 0U , 0U , 0U };
588595 DataCopyExtParams weightParams{1 , static_cast <uint32_t >(sizeof (float )), 0 , 0 , 0 };
589596 DataCopyPadExtParams<uint32_t > weightExtParams{false , 0U , 0U , 0U };
590597 DataCopyExtParams scalesParams{1 , static_cast <uint32_t >(sizeof (float )), 0 , 0 , 0 };
@@ -601,14 +608,15 @@ __aicore__ inline void MoeDistributeDispatchA2Pipeline<TemplateMC2TypeA2Pipeline
601608 tokenStructParams, tokenStructPadParams);
602609 uint32_t expertIdxStart = localMoeExpertNum_ * rankId_;
603610 uint32_t expertIdxEnd = expertIdxStart + localMoeExpertNum_;
604- LocalTensor<int > tokenIdxInStructTensor = tokenStructInHccsTensor_[cntOffsetInStruct_].Reinterpret <int >();
611+ LocalTensor<int > tokenIdxInStructTensor = tokenStructInHccsTensor_[cntOffsetInStruct_].ReinterpretCast <int >();
612+ LocalTensor<uint8_t > tokenIdxInStructToGmTensor = tokenStructInHccsTensor_[cntOffsetInStruct_];
605613 uint32_t tokenIdx = tokenIdxInStructTensor.GetValue (0 );
606614 if (tokenIdx < 0 ) {
607615 continue ;
608616 }
609- LocalTensor<float > weightTensor = tokenStructInHccsTensor_[weightOffsetInStruct_].Reinterpret <float >();
610- LocalTensor<ExpandXOutType> tokenOutTensor = tokenStructInHccsTensor_.Reinterpret <ExpandXOutType>();
611- LocalTensor<int > topkIdxTensor = tokenStructInHccsTensor_[expOffsetInStruct_].Reinterpret <int >();
617+ LocalTensor<float > weightTensor = tokenStructInHccsTensor_[weightOffsetInStruct_].ReinterpretCast <float >();
618+ LocalTensor<ExpandXOutType> tokenOutTensor = tokenStructInHccsTensor_.ReinterpretCast <ExpandXOutType>();
619+ LocalTensor<int > topkIdxTensor = tokenStructInHccsTensor_[expOffsetInStruct_].ReinterpretCast <int >();
612620 uint32_t dstOffset = 0 ;
613621 for (int j = 0 ; j < axisK_; j++) {
614622 SyncFunc<AscendC::HardEvent::MTE3_S>();
@@ -622,15 +630,14 @@ __aicore__ inline void MoeDistributeDispatchA2Pipeline<TemplateMC2TypeA2Pipeline
622630 DataCopyPad (expandXOutGMTensor_[dstOffset], tokenOutTensor, tokenParams);
623631 // dynamic scales to output
624632 if constexpr (DynamicQuant) {
625- LocalTensor<float > quantTempUB = localUB [scaleOffsetInStruct_].ReinterpretCast <float >();
633+ LocalTensor<float > quantTempUB = tokenStructInHccsTensor_ [scaleOffsetInStruct_].ReinterpretCast <float >();
626634 DataCopyPad (dynamicScalesOutGMTensor_[dstOffset], quantTempUB, scalesParams);
627635 }
628636 }
629637 tokenIdxInStructTensor.SetValue (0 , -1 );
630- DataCopyPad (hccsRecvRingU8Tensor_[tokenStructLen_ * i], tokenIdxInStructTensor, tokenStructParams,
631- tokenStructPadParams);
638+ DataCopyPad (hccsRecvRingU8Tensor_[tokenStructLen_ * i], tokenIdxInStructToGmTensor, tokenStructParams);
632639 uint32_t hcclHead = hccsHeadTailTensor_.GetValue (localRankId * 2 ); // 需要一个锁,避免多个core同时更新本rank的head
633- hccsTailTensor_. SeteValue (localRankId, hcclHead + 1 );
640+ hccsHeadTailTensor_. SetValue (localRankId, hcclHead + 1 );
634641 ++processedTokens;
635642 }
636643 }
0 commit comments