Skip to content

Commit 7d259d9

Browse files
committed
Use cub load and store warp transpose
1 parent 7e559f3 commit 7d259d9

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

ggml/src/ggml-cuda/ssm-scan.cu

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,16 +41,17 @@ __global__ void __launch_bounds__(splitD, 1)
4141
__shared__ float smemC[N];
4242

4343
#ifdef USE_CUB
44-
using BlockLoadA = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_VECTORIZE>;
45-
using BlockLoadS0 = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_VECTORIZE>;
46-
using BlockStoreS = cub::BlockStore<float, splitD, N, cub::BLOCK_STORE_VECTORIZE>;
44+
using BlockLoad = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
45+
using BlockStore = cub::BlockStore<float, splitD, N, cub::BLOCK_STORE_WARP_TRANSPOSE>;
4746

48-
__shared__ typename BlockLoadA::TempStorage block_load_tempA;
49-
__shared__ typename BlockLoadS0::TempStorage block_load_tempS0;
50-
__shared__ typename BlockStoreS::TempStorage block_store_tempS;
47+
union CubTempStorage {
48+
typename BlockLoad::TempStorage load_temp;
49+
typename BlockStore::TempStorage store_temp;
50+
};
51+
__shared__ CubTempStorage cub_temp_storage;
5152

52-
BlockLoadA(block_load_tempA).Load(A_block, regA);
53-
BlockLoadS0(block_load_tempS0).Load(s0_block, regs0);
53+
BlockLoad(cub_temp_storage.load_temp).Load(A_block, regA);
54+
BlockLoad(cub_temp_storage.load_temp).Load(s0_block, regs0);
5455
#else
5556
const int stride_s0 = src0_nb1 / sizeof(float);
5657
const int stride_A = src3_nb1 / sizeof(float);
@@ -91,7 +92,7 @@ __global__ void __launch_bounds__(splitD, 1)
9192
}
9293

9394
#ifdef USE_CUB
94-
BlockStoreS(block_store_tempS).Store(s_block, regs0);
95+
BlockStore(cub_temp_storage.store_temp).Store(s_block, regs0);
9596
#else
9697
const int stride_s = stride_s0;
9798
#pragma unroll

0 commit comments

Comments
 (0)