Skip to content

Commit 31305dd

Browse files
chuangz0pcastonguay
authored andcommitted
fix_kvcache_split
Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
1 parent ea640a1 commit 31305dd

File tree

1 file changed

+17
-13
lines changed

1 file changed

+17
-13
lines changed

cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.cu

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -615,12 +615,13 @@ __global__ void splitKVCacheForMLAKernel(T const** __restrict__ inputBlocks, T**
615615
T* outputCachePtr = outputCaches[outputCacheIdx];
616616

617617
int const headIdInDomainTP = headId;
618-
int const blockIdInDomainCP = blockId / domainCPSize;
618+
int64_t const blockIdInDomainCP = blockId / domainCPSize;
619619

620620
T* kOutputPtr = outputCachePtr
621-
+ blockIdInDomainCP * (layerNumInSpecPP * kvFactor * headNum * tokensPerBlock * dimsPerHead)
622-
+ layerIdInDomainPP * kvFactor * headNum * tokensPerBlock * dimsPerHead
623-
+ headIdInDomainTP * tokensPerBlock * dimsPerHead;
621+
+ blockIdInDomainCP
622+
* (static_cast<int64_t>(layerNumInSpecPP * kvFactor * headNum * tokensPerBlock * dimsPerHead))
623+
+ static_cast<int64_t>(layerIdInDomainPP) * kvFactor * headNum * tokensPerBlock * dimsPerHead
624+
+ static_cast<int64_t>(headIdInDomainTP) * tokensPerBlock * dimsPerHead;
624625
int const kvOffset = headNum * tokensPerBlock * dimsPerHead;
625626
#pragma unroll 1
626627
for (int tokenId = subWarpId; tokenId < tokensPerBlock; tokenId += subWarpNum)
@@ -698,9 +699,10 @@ __global__ void splitKVCacheKernel(T const** __restrict__ inputBlocks, T** __res
698699

699700
int headIdInDomainTP = headId % headNumDomainTP;
700701
T* kOutputPtr = outputCachePtr
701-
+ blockId * (layerNumInSpecPP * 2 * headNumDomainTP * tokensPerBlock * dimsPerHead)
702-
+ layerIdInDomainPP * 2 * headNumDomainTP * tokensPerBlock * dimsPerHead
703-
+ headIdInDomainTP * tokensPerBlock * dimsPerHead;
702+
+ static_cast<int64_t>(blockId)
703+
* (static_cast<int64_t>(layerNumInSpecPP * 2 * headNumDomainTP * tokensPerBlock * dimsPerHead))
704+
+ static_cast<int64_t>(layerIdInDomainPP) * 2 * headNumDomainTP * tokensPerBlock * dimsPerHead
705+
+ static_cast<int64_t>(headIdInDomainTP) * tokensPerBlock * dimsPerHead;
704706

705707
T* vOutputPtr = kOutputPtr + headNumDomainTP * tokensPerBlock * dimsPerHead;
706708
#pragma unroll 1
@@ -872,9 +874,10 @@ __global__ void concatKVCacheForMLAKernel(T const** __restrict__ inputCaches, T*
872874
int headIdInDomainTP = headId;
873875

874876
T const* kInputPtr = inputCachePtr
875-
+ blockId * (layerNumInSpecPP * kvFactor * headNum * tokensPerBlock * dimsPerHead)
876-
+ layerIdInDomainPP * kvFactor * headNum * tokensPerBlock * dimsPerHead
877-
+ headIdInDomainTP * tokensPerBlock * dimsPerHead;
877+
+ static_cast<int64_t>(blockId)
878+
* (static_cast<int64_t>(layerNumInSpecPP * kvFactor * headNum * tokensPerBlock * dimsPerHead))
879+
+ static_cast<int64_t>(layerIdInDomainPP) * kvFactor * headNum * tokensPerBlock * dimsPerHead
880+
+ static_cast<int64_t>(headIdInDomainTP) * tokensPerBlock * dimsPerHead;
878881
int const kvOffset = headNum * tokensPerBlock * dimsPerHead;
879882
#pragma unroll 1
880883
for (int tokenId = subWarpId; tokenId < tokensPerBlock; tokenId += subWarpNum)
@@ -939,9 +942,10 @@ __global__ void concatKVCacheKernel(T const** __restrict__ inputCaches, T** __re
939942

940943
int headIdInDomainTP = headId % headNumDomainTP;
941944
T const* kInputPtr = inputCachePtr
942-
+ blockId * (layerNumInSpecPP * 2 * headNumDomainTP * tokensPerBlock * dimsPerHead)
943-
+ layerIdInDomainPP * 2 * headNumDomainTP * tokensPerBlock * dimsPerHead
944-
+ headIdInDomainTP * tokensPerBlock * dimsPerHead;
945+
+ static_cast<int64_t>(blockId)
946+
* (static_cast<int64_t>(layerNumInSpecPP * 2 * headNumDomainTP * tokensPerBlock * dimsPerHead))
947+
+ static_cast<int64_t>(layerIdInDomainPP) * 2 * headNumDomainTP * tokensPerBlock * dimsPerHead
948+
+ static_cast<int64_t>(headIdInDomainTP) * tokensPerBlock * dimsPerHead;
945949

946950
T const* vInputPtr = kInputPtr + headNumDomainTP * tokensPerBlock * dimsPerHead;
947951
#pragma unroll 1

0 commit comments

Comments
 (0)