@@ -3061,8 +3061,8 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
30613061 T local_A[1 ];
30623062 T local_B[32 ];
30633063
3064- const int a_tile_offset = (8 *16 );
3065- const int b_tile_offset = (16 *32 );
3064+ const int a_tile_offset = (8 *16 + 16 );
3065+ const int b_tile_offset = (16 *32 + 16 );
30663066
30673067 __shared__ T smem_A[2 *batch_size_warps*8 *16 + (2 *16 *(batch_size_warps-1 ))];
30683068 __shared__ T smem_B[2 *batch_size_warps*16 *32 + (2 *16 *(batch_size_warps-1 ))];
@@ -3074,23 +3074,10 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
30743074
30753075 wmma::fill_fragment (c_frag, 0 .0f );
30763076
3077-
3078- // for(int i = threadIdx.x; i < 16*16*WARPS; i+=blockDim.x)
3079- // smem_A[i] = T(0);
3080-
3081- // for(int i = threadIdx.x; i < 16*16*WARPS; i+=blockDim.x)
3082- // smem_B[i] = T(0);
3083-
30843077 for (int i = threadIdx .x ; i < 8 *32 ; i+=blockDim .x )
30853078 smem_C[i] = T (0 );
30863079 __syncthreads ();
30873080
3088- // #pragma unroll 8
3089- // for(int k = 0; k < 8; k++)
3090- // local_C[k] = T(0);
3091-
3092- // int block_idx = 0;
3093- // for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x)
30943081 int ticktock = 0 ;
30953082 int idx = 0 + threadIdx .x ;
30963083 // prefetch
@@ -3102,29 +3089,29 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
31023089 for (int col = 0 ; col < 32 ; col++)
31033090 local_B[col] = B[(col_offset+col)*ldb+idx];
31043091
3105- smem_A[half_warp_lane + (half_warp_id*a_tile_offset)] = local_A[0 ];
3092+ smem_A[half_warp_lane + (((batch_size_warps*ticktock)+ half_warp_id) *a_tile_offset)] = local_A[0 ];
31063093
31073094 #pragma unroll 32
31083095 for (int col = 0 ; col < 32 ; col++)
3109- smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16 )] = local_B[col];
3096+ smem_B[half_warp_lane + (((batch_size_warps*ticktock)+ half_warp_id) *b_tile_offset) + (col*16 )] = local_B[col];
31103097 }
31113098 else if (warp_id < (WARPS-1 ))
31123099 {
31133100 local_A[0 ] = T (0.0 );
3114- smem_A[half_warp_lane + (half_warp_id*a_tile_offset)] = T ( 0.0 ) ;
3101+ smem_A[half_warp_lane + (((batch_size_warps*ticktock)+ half_warp_id) *a_tile_offset)] = 0 . 0f ;
31153102
31163103 #pragma unroll 32
31173104 for (int col = 0 ; col < 32 ; col++)
3118- local_B[col] = T ( 0 .0f ) ;
3105+ local_B[col] = 0 .0f ;
31193106
31203107 #pragma unroll 32
31213108 for (int col = 0 ; col < 32 ; col++)
3122- smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16 )] = T ( 0 .0f ) ;
3109+ smem_B[half_warp_lane + (((batch_size_warps*ticktock)+ half_warp_id) *b_tile_offset) + (col*16 )] = 0 .0f ;
31233110 }
31243111 ticktock = ticktock == 0 ? 1 : 0 ;
31253112
31263113 // for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
3127- for (int base_idx = 0 ; base_idx < K; base_idx+=blockDim .x -32 )
3114+ for (int base_idx = blockDim . x - 32 ; base_idx < K; base_idx+=blockDim .x -32 )
31283115 {
31293116 idx = base_idx + threadIdx .x ;
31303117
@@ -3156,7 +3143,7 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
31563143 for (int col = 0 ; col < 32 ; col++)
31573144 smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16 )] = 0 .0f ;
31583145 }
3159- // ticktock = ticktock == 0 ? 1 : 0;
3146+ ticktock = ticktock == 0 ? 1 : 0 ;
31603147
31613148 __syncthreads ();
31623149 if (warp_id == (WARPS-1 ))
@@ -3168,14 +3155,15 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
31683155 }
31693156 }
31703157
3171- // __syncthreads();
3172- // if(warp_id == (WARPS-1))
3173- // for(int k = 0; k < batch_size_warps; k++)
3174- // {
3175- // wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu
3176- // wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
3177- // wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
3178- // }
3158+ __syncthreads ();
3159+ ticktock = ticktock == 0 ? 1 : 0 ;
3160+ if (warp_id == (WARPS-1 ))
3161+ for (int k = 0 ; k < batch_size_warps; k++)
3162+ {
3163+ wmma::load_matrix_sync (a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16 ); // 111 mu
3164+ wmma::load_matrix_sync (b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16 ); // 35 mu
3165+ wmma::mma_sync (c_frag, a_frag, b_frag, c_frag);
3166+ }
31793167 __syncthreads ();
31803168
31813169 // 129 mu
0 commit comments