Skip to content

Commit 3d4a2ea

Browse files
committed
16x16 240.
1 parent 7cc8ff4 commit 3d4a2ea

File tree

2 files changed

+27
-27
lines changed

2 files changed

+27
-27
lines changed

csrc/kernels.cu

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3052,37 +3052,37 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
30523052
//typedef cub::BlockReduce<T, THREADS> BlockReduce;
30533053
//// Allocate shared memory for BlockReduce
30543054
//__shared__ typename BlockReduce::TempStorage reduce;
3055-
int col_offset = blockIdx.x *8;
3055+
int col_offset = blockIdx.x *16;
30563056
const int warp_id = threadIdx.x / 32;
30573057
const int half_warp_id = threadIdx.x / 16;
30583058
const int half_warp_lane = threadIdx.x % 16;
30593059
const int batch_size_warps = (WARPS-1)*2;
30603060

30613061
T local_A[1];
3062-
T local_B[8];
3062+
T local_B[16];
30633063

3064-
const int a_tile_offset = (32*16 + 16);
3065-
const int b_tile_offset = (16*8 + 16);
3066-
const int c_tile_offset = 32*8 + 24;
3064+
const int a_tile_offset = (16*16 + 16);
3065+
const int b_tile_offset = (16*16 + 16);
3066+
const int c_tile_offset = 16*16 + 24;
30673067

3068-
__shared__ T smem_A[2*batch_size_warps*32*16 + (2*16*(batch_size_warps-1))];
3069-
__shared__ T smem_B[2*batch_size_warps*16*8 + (2*16*(batch_size_warps-1))];
3070-
__shared__ T smem_C[32*8];
3068+
__shared__ T smem_A[2*batch_size_warps*16*16 + (2*16*(batch_size_warps-1))];
3069+
__shared__ T smem_B[2*batch_size_warps*16*16 + (2*16*(batch_size_warps-1))];
3070+
__shared__ T smem_C[16*16];
30713071

3072-
wmma::fragment<wmma::matrix_a, 32, 8, 16, half, wmma::row_major> a_frag;
3073-
wmma::fragment<wmma::matrix_b, 32, 8, 16, half, wmma::col_major> b_frag;
3074-
wmma::fragment<wmma::accumulator, 32, 8, 16, half> c_frag;
3072+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
3073+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b_frag;
3074+
wmma::fragment<wmma::accumulator, 16, 16, 16, half> c_frag;
30753075

30763076
wmma::fill_fragment(c_frag, 0.0f);
30773077

30783078

3079-
for(int i = threadIdx.x; i < 32*16*WARPS; i+=blockDim.x)
3080-
smem_A[i] = T(0);
3079+
//for(int i = threadIdx.x; i < 16*16*WARPS; i+=blockDim.x)
3080+
// smem_A[i] = T(0);
30813081

3082-
for(int i = threadIdx.x; i < 32*8*WARPS; i+=blockDim.x)
3083-
smem_B[i] = T(0);
3082+
//for(int i = threadIdx.x; i < 16*16*WARPS; i+=blockDim.x)
3083+
// smem_B[i] = T(0);
30843084

3085-
for(int i = threadIdx.x; i < 32*8*WARPS; i+=blockDim.x)
3085+
for(int i = threadIdx.x; i < 16*16; i+=blockDim.x)
30863086
smem_C[i] = T(0);
30873087
__syncthreads();
30883088

@@ -3099,14 +3099,14 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
30993099
{
31003100
local_A[0] = A[idx];
31013101

3102-
#pragma unroll 8
3103-
for(int col = 0; col < 8; col++)
3102+
#pragma unroll 16
3103+
for(int col = 0; col < 16; col++)
31043104
local_B[col] = B[(col_offset+col)*ldb+idx];
31053105

31063106
smem_A[half_warp_lane + (half_warp_id*a_tile_offset)] = local_A[0];
31073107

3108-
#pragma unroll 8
3109-
for(int col = 0; col < 8; col++)
3108+
#pragma unroll 16
3109+
for(int col = 0; col < 16; col++)
31103110
smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16)] = local_B[col];
31113111
}
31123112
ticktock = ticktock == 0 ? 1 : 0;
@@ -3120,14 +3120,14 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
31203120
{
31213121
local_A[0] = A[idx];
31223122

3123-
#pragma unroll 8
3124-
for(int col = 0; col < 8; col++)
3123+
#pragma unroll 16
3124+
for(int col = 0; col < 16; col++)
31253125
local_B[col] = B[(col_offset+col)*ldb+idx];
31263126

31273127
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];
31283128

3129-
#pragma unroll 8
3130-
for(int col = 0; col < 8; col++)
3129+
#pragma unroll 16
3130+
for(int col = 0; col < 16; col++)
31313131
smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col];
31323132
}
31333133
ticktock = ticktock == 0 ? 1 : 0;
@@ -3143,7 +3143,7 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
31433143

31443144
// 129 mu
31453145
if(warp_id == (WARPS-1))
3146-
wmma::store_matrix_sync(&(smem_C[0]), c_frag, 8, wmma::mem_row_major);
3146+
wmma::store_matrix_sync(&(smem_C[0]), c_frag, 16, wmma::mem_row_major);
31473147
__syncthreads();
31483148

31493149
//if(threadIdx.x >= 16){ return; }
@@ -3185,7 +3185,7 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
31853185

31863186
//if(threadIdx.x < 8 && col_offset + threadIdx.x < M)
31873187
//out[col_offset + threadIdx.x ] = smem_C[threadIdx.x];
3188-
if(threadIdx.x < 8 && col_offset + threadIdx.x < M)
3188+
if(threadIdx.x < 16 && col_offset + threadIdx.x < M)
31893189
out[col_offset + threadIdx.x] = smem_C[threadIdx.x];
31903190
}
31913191

csrc/ops.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,7 @@ void pipeline_test(float *A, float *B, size_t n, size_t batch_size)
678678
template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits)
679679
{
680680

681-
int num_blocks = (m+7)/8;
681+
int num_blocks = (m+15)/16;
682682

683683
cout << num_blocks << endl;
684684
cout << lda << endl;

0 commit comments

Comments
 (0)