@@ -41,16 +41,17 @@ __global__ void __launch_bounds__(splitD, 1)
4141 __shared__ float smemC[N];
4242
4343#ifdef USE_CUB
44- using BlockLoadA = cub::BlockLoad<float , splitD, N, cub::BLOCK_LOAD_VECTORIZE>;
45- using BlockLoadS0 = cub::BlockLoad<float , splitD, N, cub::BLOCK_LOAD_VECTORIZE>;
46- using BlockStoreS = cub::BlockStore<float , splitD, N, cub::BLOCK_STORE_VECTORIZE>;
44+ using BlockLoad = cub::BlockLoad<float , splitD, N, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
45+ using BlockStore = cub::BlockStore<float , splitD, N, cub::BLOCK_STORE_WARP_TRANSPOSE>;
4746
48- __shared__ typename BlockLoadA::TempStorage block_load_tempA;
49- __shared__ typename BlockLoadS0::TempStorage block_load_tempS0;
50- __shared__ typename BlockStoreS::TempStorage block_store_tempS;
47+ union CubTempStorage {
48+ typename BlockLoad::TempStorage load_temp;
49+ typename BlockStore::TempStorage store_temp;
50+ };
51+ __shared__ CubTempStorage cub_temp_storage;
5152
52- BlockLoadA (block_load_tempA ).Load (A_block, regA);
53- BlockLoadS0 (block_load_tempS0 ).Load (s0_block, regs0);
53+ BlockLoad (cub_temp_storage. load_temp ).Load (A_block, regA);
54+ BlockLoad (cub_temp_storage. load_temp ).Load (s0_block, regs0);
5455#else
5556 const int stride_s0 = src0_nb1 / sizeof (float );
5657 const int stride_A = src3_nb1 / sizeof (float );
@@ -91,7 +92,7 @@ __global__ void __launch_bounds__(splitD, 1)
9192 }
9293
9394#ifdef USE_CUB
94- BlockStoreS (block_store_tempS ).Store (s_block, regs0);
95+ BlockStore (cub_temp_storage. store_temp ).Store (s_block, regs0);
9596#else
9697 const int stride_s = stride_s0;
9798#pragma unroll
0 commit comments