@@ -615,12 +615,13 @@ __global__ void splitKVCacheForMLAKernel(T const** __restrict__ inputBlocks, T**
615615 T* outputCachePtr = outputCaches[outputCacheIdx];
616616
617617 int const headIdInDomainTP = headId;
618- int const blockIdInDomainCP = blockId / domainCPSize;
618+ int64_t const blockIdInDomainCP = blockId / domainCPSize;
619619
620620 T* kOutputPtr = outputCachePtr
621- + blockIdInDomainCP * (layerNumInSpecPP * kvFactor * headNum * tokensPerBlock * dimsPerHead)
622- + layerIdInDomainPP * kvFactor * headNum * tokensPerBlock * dimsPerHead
623- + headIdInDomainTP * tokensPerBlock * dimsPerHead;
621+ + blockIdInDomainCP
622+ * (static_cast <int64_t >(layerNumInSpecPP * kvFactor * headNum * tokensPerBlock * dimsPerHead))
623+ + static_cast <int64_t >(layerIdInDomainPP) * kvFactor * headNum * tokensPerBlock * dimsPerHead
624+ + static_cast <int64_t >(headIdInDomainTP) * tokensPerBlock * dimsPerHead;
624625 int const kvOffset = headNum * tokensPerBlock * dimsPerHead;
625626#pragma unroll 1
626627 for (int tokenId = subWarpId; tokenId < tokensPerBlock; tokenId += subWarpNum)
@@ -698,9 +699,10 @@ __global__ void splitKVCacheKernel(T const** __restrict__ inputBlocks, T** __res
698699
699700 int headIdInDomainTP = headId % headNumDomainTP;
700701 T* kOutputPtr = outputCachePtr
701- + blockId * (layerNumInSpecPP * 2 * headNumDomainTP * tokensPerBlock * dimsPerHead)
702- + layerIdInDomainPP * 2 * headNumDomainTP * tokensPerBlock * dimsPerHead
703- + headIdInDomainTP * tokensPerBlock * dimsPerHead;
702+ + static_cast <int64_t >(blockId)
703+ * (static_cast <int64_t >(layerNumInSpecPP * 2 * headNumDomainTP * tokensPerBlock * dimsPerHead))
704+ + static_cast <int64_t >(layerIdInDomainPP) * 2 * headNumDomainTP * tokensPerBlock * dimsPerHead
705+ + static_cast <int64_t >(headIdInDomainTP) * tokensPerBlock * dimsPerHead;
704706
705707 T* vOutputPtr = kOutputPtr + headNumDomainTP * tokensPerBlock * dimsPerHead;
706708#pragma unroll 1
@@ -872,9 +874,10 @@ __global__ void concatKVCacheForMLAKernel(T const** __restrict__ inputCaches, T*
872874 int headIdInDomainTP = headId;
873875
874876 T const * kInputPtr = inputCachePtr
875- + blockId * (layerNumInSpecPP * kvFactor * headNum * tokensPerBlock * dimsPerHead)
876- + layerIdInDomainPP * kvFactor * headNum * tokensPerBlock * dimsPerHead
877- + headIdInDomainTP * tokensPerBlock * dimsPerHead;
877+ + static_cast <int64_t >(blockId)
878+ * (static_cast <int64_t >(layerNumInSpecPP * kvFactor * headNum * tokensPerBlock * dimsPerHead))
879+ + static_cast <int64_t >(layerIdInDomainPP) * kvFactor * headNum * tokensPerBlock * dimsPerHead
880+ + static_cast <int64_t >(headIdInDomainTP) * tokensPerBlock * dimsPerHead;
878881 int const kvOffset = headNum * tokensPerBlock * dimsPerHead;
879882#pragma unroll 1
880883 for (int tokenId = subWarpId; tokenId < tokensPerBlock; tokenId += subWarpNum)
@@ -939,9 +942,10 @@ __global__ void concatKVCacheKernel(T const** __restrict__ inputCaches, T** __re
939942
940943 int headIdInDomainTP = headId % headNumDomainTP;
941944 T const * kInputPtr = inputCachePtr
942- + blockId * (layerNumInSpecPP * 2 * headNumDomainTP * tokensPerBlock * dimsPerHead)
943- + layerIdInDomainPP * 2 * headNumDomainTP * tokensPerBlock * dimsPerHead
944- + headIdInDomainTP * tokensPerBlock * dimsPerHead;
945+ + static_cast <int64_t >(blockId)
946+ * (static_cast <int64_t >(layerNumInSpecPP * 2 * headNumDomainTP * tokensPerBlock * dimsPerHead))
947+ + static_cast <int64_t >(layerIdInDomainPP) * 2 * headNumDomainTP * tokensPerBlock * dimsPerHead
948+ + static_cast <int64_t >(headIdInDomainTP) * tokensPerBlock * dimsPerHead;
945949
946950 T const * vInputPtr = kInputPtr + headNumDomainTP * tokensPerBlock * dimsPerHead;
947951#pragma unroll 1
0 commit comments