Skip to content

Commit cabcd9b

Browse files
committed
Halved shared memory 466.
1 parent 30d03e0 commit cabcd9b

File tree

1 file changed

+43
-27
lines changed

1 file changed

+43
-27
lines changed

csrc/kernels.cu

Lines changed: 43 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -3053,25 +3053,23 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
30533053
//// Allocate shared memory for BlockReduce
30543054
//__shared__ typename BlockReduce::TempStorage reduce;
30553055
int col_offset = blockIdx.x *8;
3056+
const int warp_id = threadIdx.x / 32;
30563057
const int half_warp_id = threadIdx.x / 16;
30573058
const int half_warp_lane = threadIdx.x % 16;
30583059

3059-
T local_A[64/BITS];
3060-
T local_B[64/BITS];
3061-
T local_C[8];
3060+
T local_A[1];
3061+
T local_B[8];
30623062

30633063
const int a_tile_offset = 32*16 + 16;
30643064
const int b_tile_offset = 16*8 + 16;
30653065
const int c_tile_offset = 32*8 + 24;
30663066

3067-
__shared__ T smem_A[WARPS*32*16*2 + (16*(WARPS-1))];
3068-
__shared__ T smem_B[WARPS*16*8*2 + (16*(WARPS-1))];
3067+
__shared__ T smem_A[WARPS*32*16 + (16*(WARPS-1))];
3068+
__shared__ T smem_B[WARPS*16*8 + (16*(WARPS-1))];
30693069
__shared__ T smem_C[WARPS*32*8 + (24*(WARPS-1))];
30703070

30713071
wmma::fragment<wmma::matrix_a, 32, 8, 16, half, wmma::row_major> a_frag;
30723072
wmma::fragment<wmma::matrix_b, 32, 8, 16, half, wmma::col_major> b_frag;
3073-
wmma::fragment<wmma::matrix_a, 32, 8, 16, half, wmma::row_major> a2_frag;
3074-
wmma::fragment<wmma::matrix_b, 32, 8, 16, half, wmma::col_major> b2_frag;
30753073
wmma::fragment<wmma::accumulator, 32, 8, 16, half> c_frag;
30763074

30773075
wmma::fill_fragment(c_frag, 0.0f);
@@ -3087,37 +3085,55 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
30873085
smem_C[i] = T(0);
30883086
__syncthreads();
30893087

3090-
#pragma unroll 8
3091-
for(int k = 0; k < 8; k++)
3092-
local_C[k] = T(0);
3088+
//#pragma unroll 8
3089+
//for(int k = 0; k < 8; k++)
3090+
//local_C[k] = T(0);
30933091

30943092
//int block_idx = 0;
30953093
//for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x)
30963094
for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x)
30973095
{
30983096
int idx = base_idx + threadIdx.x;
30993097

3100-
if(idx >= K)
3098+
for(int k = 0; k < 2; k++)
31013099
{
3102-
smem_A[threadIdx.x] = 0.0f;
3103-
//smem_B[threadIdx.x] = 0.0f;
3104-
}
3105-
else
3106-
{
3107-
smem_A[half_warp_lane + (half_warp_id*a_tile_offset)] = A[idx];
3100+
if(k == 0)
3101+
{
3102+
if(idx < K)
3103+
{
3104+
local_A[0] = A[idx];
31083105

3109-
for(int col = 0; col < 8; col++)
3110-
smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16)] = B[(col_offset+col)*ldb+idx];
3111-
}
3106+
#pragma unroll 8
3107+
for(int col = 0; col < 8; col++)
3108+
local_B[col] = B[(col_offset+col)*ldb+idx];
3109+
}
31123110

3113-
__syncthreads();
3111+
}
3112+
3113+
if(idx >= K)
3114+
{
3115+
smem_A[threadIdx.x] = 0.0f;
3116+
//smem_B[threadIdx.x] = 0.0f;
3117+
}
3118+
else
3119+
{
3120+
if((k == 0 && half_warp_id % 2 == 0) ||
3121+
(k == 1 && half_warp_id % 2 == 1))
3122+
{
3123+
smem_A[half_warp_lane + (warp_id*a_tile_offset)] = local_A[0];
31143124

3115-
wmma::load_matrix_sync(a_frag, &(smem_A[0]), 16); // 111 mu
3116-
wmma::load_matrix_sync(b_frag, &(smem_B[0]), 16); // 35 mu
3117-
wmma::load_matrix_sync(a2_frag, &(smem_A[half_warp_id*a_tile_offset]), 16); // 111 mu
3118-
wmma::load_matrix_sync(b2_frag, &(smem_B[half_warp_id*b_tile_offset]), 16); // 35 mu
3119-
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
3120-
wmma::mma_sync(c_frag, a2_frag, b2_frag, c_frag);
3125+
#pragma unroll 8
3126+
for(int col = 0; col < 8; col++)
3127+
smem_B[half_warp_lane + (warp_id*b_tile_offset) + (col*16)] = local_B[col];
3128+
}
3129+
}
3130+
3131+
__syncthreads();
3132+
3133+
wmma::load_matrix_sync(a_frag, &(smem_A[warp_id*a_tile_offset]), 16); // 111 mu
3134+
wmma::load_matrix_sync(b_frag, &(smem_B[warp_id*b_tile_offset]), 16); // 35 mu
3135+
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
3136+
}
31213137
}
31223138

31233139
// 129 mu

0 commit comments

Comments
 (0)