@@ -3053,25 +3053,23 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
   // // Allocate shared memory for BlockReduce
   // __shared__ typename BlockReduce::TempStorage reduce;
   int col_offset = blockIdx.x*8;
+  const int warp_id = threadIdx.x / 32;
   const int half_warp_id = threadIdx.x / 16;
   const int half_warp_lane = threadIdx.x % 16;
 
-  T local_A[64/BITS];
-  T local_B[64/BITS];
-  T local_C[8];
+  T local_A[1];
+  T local_B[8];
 
   const int a_tile_offset = 32*16 + 16;
   const int b_tile_offset = 16*8 + 16;
   const int c_tile_offset = 32*8 + 24;
 
-  __shared__ T smem_A[WARPS*32*16*2 + (16*(WARPS-1))];
-  __shared__ T smem_B[WARPS*16*8*2 + (16*(WARPS-1))];
+  __shared__ T smem_A[WARPS*32*16 + (16*(WARPS-1))];
+  __shared__ T smem_B[WARPS*16*8 + (16*(WARPS-1))];
   __shared__ T smem_C[WARPS*32*8 + (24*(WARPS-1))];
 
   wmma::fragment<wmma::matrix_a, 32, 8, 16, half, wmma::row_major> a_frag;
   wmma::fragment<wmma::matrix_b, 32, 8, 16, half, wmma::col_major> b_frag;
-  wmma::fragment<wmma::matrix_a, 32, 8, 16, half, wmma::row_major> a2_frag;
-  wmma::fragment<wmma::matrix_b, 32, 8, 16, half, wmma::col_major> b2_frag;
   wmma::fragment<wmma::accumulator, 32, 8, 16, half> c_frag;
 
   wmma::fill_fragment(c_frag, 0.0f);
@@ -3087,37 +3085,55 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
     smem_C[i] = T(0);
   __syncthreads();
 
-  #pragma unroll 8
-  for(int k = 0; k < 8; k++)
-    local_C[k] = T(0);
+  //#pragma unroll 8
+  //for(int k = 0; k < 8; k++)
+  //  local_C[k] = T(0);
 
   //int block_idx = 0;
   //for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x)
   for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x)
   {
     int idx = base_idx + threadIdx.x;
 
-    if(idx >= K)
+    for(int k = 0; k < 2; k++)
     {
-      smem_A[threadIdx.x] = 0.0f;
-      //smem_B[threadIdx.x] = 0.0f;
-    }
-    else
-    {
-      smem_A[half_warp_lane + (half_warp_id*a_tile_offset)] = A[idx];
+      if(k == 0)
+      {
+        if(idx < K)
+        {
+          local_A[0] = A[idx];
 
-      for(int col = 0; col < 8; col++)
-        smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16)] = B[(col_offset+col)*ldb+idx];
-    }
+          #pragma unroll 8
+          for(int col = 0; col < 8; col++)
+            local_B[col] = B[(col_offset+col)*ldb+idx];
+        }
 
-    __syncthreads();
+      }
+
+      if(idx >= K)
+      {
+        smem_A[threadIdx.x] = 0.0f;
+        //smem_B[threadIdx.x] = 0.0f;
+      }
+      else
+      {
+        if((k == 0 && half_warp_id % 2 == 0) ||
+           (k == 1 && half_warp_id % 2 == 1))
+        {
+          smem_A[half_warp_lane + (warp_id*a_tile_offset)] = local_A[0];
 
-    wmma::load_matrix_sync(a_frag, &(smem_A[0]), 16); // 111 mu
-    wmma::load_matrix_sync(b_frag, &(smem_B[0]), 16); // 35 mu
-    wmma::load_matrix_sync(a2_frag, &(smem_A[half_warp_id*a_tile_offset]), 16); // 111 mu
-    wmma::load_matrix_sync(b2_frag, &(smem_B[half_warp_id*b_tile_offset]), 16); // 35 mu
-    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
-    wmma::mma_sync(c_frag, a2_frag, b2_frag, c_frag);
+          #pragma unroll 8
+          for(int col = 0; col < 8; col++)
+            smem_B[half_warp_lane + (warp_id*b_tile_offset) + (col*16)] = local_B[col];
+        }
+      }
+
+      __syncthreads();
+
+      wmma::load_matrix_sync(a_frag, &(smem_A[warp_id*a_tile_offset]), 16); // 111 mu
+      wmma::load_matrix_sync(b_frag, &(smem_B[warp_id*b_tile_offset]), 16); // 35 mu
+      wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
+    }
   }
 
   // 129 mu
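
For reference, below is a minimal standalone sketch of the pattern this commit moves to: each warp owns one padded 32x16 A tile and one 16x8 B tile in shared memory and issues a single 32x8x16 wmma MMA per iteration, replacing the earlier two-fragment (a2_frag/b2_frag) variant. This is not the gemm_device kernel itself, only a one-warp illustration of the fragment shapes, leading dimensions, and tile padding used above; the kernel name wmma_tile_sketch and the host driver are made up for the example, and the claim that the extra +16/+24 elements exist to keep per-warp tiles apart in shared memory is an assumption. Requires sm_70 or newer.

// wmma_tile_sketch.cu: illustrative only; compile with: nvcc -arch=sm_70 wmma_tile_sketch.cu
#include <cuda_fp16.h>
#include <mma.h>

using namespace nvcuda;

// One warp stages a 32x16 tile of A (row-major) and a 16x8 tile of B
// (column-major) in shared memory, then runs a single 32x8x16 tensor-core
// MMA, using the same fragment shapes and ldm values as in the diff above.
__global__ void wmma_tile_sketch(const half *A, const half *B, half *C)
{
  // +16 halves of padding per tile, mirroring a_tile_offset/b_tile_offset
  // (assumed to separate per-warp tiles in shared memory).
  __shared__ half smem_A[32*16 + 16];
  __shared__ half smem_B[16*8 + 16];

  // 32 threads cooperatively copy both tiles into shared memory.
  for(int i = threadIdx.x; i < 32*16; i += 32)
    smem_A[i] = A[i];                       // row-major, 16 columns -> ldm 16
  for(int i = threadIdx.x; i < 16*8; i += 32)
    smem_B[i] = B[i];                       // column-major, 16 rows -> ldm 16
  __syncthreads();

  wmma::fragment<wmma::matrix_a, 32, 8, 16, half, wmma::row_major> a_frag;
  wmma::fragment<wmma::matrix_b, 32, 8, 16, half, wmma::col_major> b_frag;
  wmma::fragment<wmma::accumulator, 32, 8, 16, half> c_frag;

  wmma::fill_fragment(c_frag, 0.0f);
  wmma::load_matrix_sync(a_frag, &smem_A[0], 16);
  wmma::load_matrix_sync(b_frag, &smem_B[0], 16);
  wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);

  // Write the 32x8 result tile back row-major (ldm 8).
  wmma::store_matrix_sync(C, c_frag, 8, wmma::mem_row_major);
}

int main()
{
  half *A, *B, *C;
  cudaMalloc((void **)&A, 32*16*sizeof(half));
  cudaMalloc((void **)&B, 16*8*sizeof(half));
  cudaMalloc((void **)&C, 32*8*sizeof(half));
  cudaMemset(A, 0, 32*16*sizeof(half));
  cudaMemset(B, 0, 16*8*sizeof(half));

  wmma_tile_sketch<<<1, 32>>>(A, B, C);     // one warp
  cudaDeviceSynchronize();

  cudaFree(A); cudaFree(B); cudaFree(C);
  return 0;
}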