@@ -3052,37 +3052,37 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
30523052 // typedef cub::BlockReduce<T, THREADS> BlockReduce;
30533053 // // Allocate shared memory for BlockReduce
30543054 // __shared__ typename BlockReduce::TempStorage reduce;
3055- int col_offset = blockIdx .x *8 ;
3055+ int col_offset = blockIdx .x *16 ;
30563056 const int warp_id = threadIdx .x / 32 ;
30573057 const int half_warp_id = threadIdx .x / 16 ;
30583058 const int half_warp_lane = threadIdx .x % 16 ;
30593059 const int batch_size_warps = (WARPS-1 )*2 ;
30603060
30613061 T local_A[1 ];
3062- T local_B[8 ];
3062+ T local_B[16 ];
30633063
3064- const int a_tile_offset = (32 *16 + 16 );
3065- const int b_tile_offset = (16 *8 + 16 );
3066- const int c_tile_offset = 32 * 8 + 24 ;
3064+ const int a_tile_offset = (16 *16 + 16 );
3065+ const int b_tile_offset = (16 *16 + 16 );
3066+ const int c_tile_offset = 16 * 16 + 24 ;
30673067
3068- __shared__ T smem_A[2 *batch_size_warps*32 *16 + (2 *16 *(batch_size_warps-1 ))];
3069- __shared__ T smem_B[2 *batch_size_warps*16 *8 + (2 *16 *(batch_size_warps-1 ))];
3070- __shared__ T smem_C[32 * 8 ];
3068+ __shared__ T smem_A[2 *batch_size_warps*16 *16 + (2 *16 *(batch_size_warps-1 ))];
3069+ __shared__ T smem_B[2 *batch_size_warps*16 *16 + (2 *16 *(batch_size_warps-1 ))];
3070+ __shared__ T smem_C[16 * 16 ];
30713071
3072- wmma::fragment<wmma::matrix_a, 32 , 8 , 16 , half, wmma::row_major> a_frag;
3073- wmma::fragment<wmma::matrix_b, 32 , 8 , 16 , half, wmma::col_major> b_frag;
3074- wmma::fragment<wmma::accumulator, 32 , 8 , 16 , half> c_frag;
3072+ wmma::fragment<wmma::matrix_a, 16 , 16 , 16 , half, wmma::row_major> a_frag;
3073+ wmma::fragment<wmma::matrix_b, 16 , 16 , 16 , half, wmma::col_major> b_frag;
3074+ wmma::fragment<wmma::accumulator, 16 , 16 , 16 , half> c_frag;
30753075
30763076 wmma::fill_fragment (c_frag, 0 .0f );
30773077
30783078
3079- for (int i = threadIdx .x ; i < 32 *16 *WARPS; i+=blockDim .x )
3080- smem_A[i] = T (0 );
3079+ // for(int i = threadIdx.x; i < 16 *16*WARPS; i+=blockDim.x)
3080+ // smem_A[i] = T(0);
30813081
3082- for (int i = threadIdx .x ; i < 32 * 8 *WARPS; i+=blockDim .x )
3083- smem_B[i] = T (0 );
3082+ // for(int i = threadIdx.x; i < 16*16 *WARPS; i+=blockDim.x)
3083+ // smem_B[i] = T(0);
30843084
3085- for (int i = threadIdx .x ; i < 32 * 8 *WARPS ; i+=blockDim .x )
3085+ for (int i = threadIdx .x ; i < 16 * 16 ; i+=blockDim .x )
30863086 smem_C[i] = T (0 );
30873087 __syncthreads ();
30883088
@@ -3099,14 +3099,14 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
30993099 {
31003100 local_A[0 ] = A[idx];
31013101
3102- #pragma unroll 8
3103- for (int col = 0 ; col < 8 ; col++)
3102+ #pragma unroll 16
3103+ for (int col = 0 ; col < 16 ; col++)
31043104 local_B[col] = B[(col_offset+col)*ldb+idx];
31053105
31063106 smem_A[half_warp_lane + (half_warp_id*a_tile_offset)] = local_A[0 ];
31073107
3108- #pragma unroll 8
3109- for (int col = 0 ; col < 8 ; col++)
3108+ #pragma unroll 16
3109+ for (int col = 0 ; col < 16 ; col++)
31103110 smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16 )] = local_B[col];
31113111 }
31123112 ticktock = ticktock == 0 ? 1 : 0 ;
@@ -3120,14 +3120,14 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
31203120 {
31213121 local_A[0 ] = A[idx];
31223122
3123- #pragma unroll 8
3124- for (int col = 0 ; col < 8 ; col++)
3123+ #pragma unroll 16
3124+ for (int col = 0 ; col < 16 ; col++)
31253125 local_B[col] = B[(col_offset+col)*ldb+idx];
31263126
31273127 smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0 ];
31283128
3129- #pragma unroll 8
3130- for (int col = 0 ; col < 8 ; col++)
3129+ #pragma unroll 16
3130+ for (int col = 0 ; col < 16 ; col++)
31313131 smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16 )] = local_B[col];
31323132 }
31333133 ticktock = ticktock == 0 ? 1 : 0 ;
@@ -3143,7 +3143,7 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
31433143
31443144 // 129 mu
31453145 if (warp_id == (WARPS-1 ))
3146- wmma::store_matrix_sync (&(smem_C[0 ]), c_frag, 8 , wmma::mem_row_major);
3146+ wmma::store_matrix_sync (&(smem_C[0 ]), c_frag, 16 , wmma::mem_row_major);
31473147 __syncthreads ();
31483148
31493149 // if(threadIdx.x >= 16){ return; }
@@ -3185,7 +3185,7 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
31853185
31863186 // if(threadIdx.x < 8 && col_offset + threadIdx.x < M)
31873187 // out[col_offset + threadIdx.x ] = smem_C[threadIdx.x];
3188- if (threadIdx .x < 8 && col_offset + threadIdx .x < M)
3188+ if (threadIdx .x < 16 && col_offset + threadIdx .x < M)
31893189 out[col_offset + threadIdx .x ] = smem_C[threadIdx .x ];
31903190}
31913191
0 commit comments