@@ -3058,8 +3058,8 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
30583058 const int half_warp_lane = threadIdx .x % 16 ;
30593059 const int batch_size_warps = (WARPS-1 )*2 ;
30603060
3061- T local_A[1 ];
3062- T local_B[32 ];
3061+ T local_A[2 ];
3062+ T local_B[64 ];
30633063
30643064 const int a_tile_offset = 16 ;
30653065 const int b_tile_offset = (16 *32 + 16 );
@@ -3075,14 +3075,32 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
30753075
30763076 int ticktock = 0 ;
30773077 int idx = 0 + threadIdx .x ;
3078+ int loaded_values = 0 ;
30783079 // prefetch
30793080 if (idx < K && warp_id < (WARPS-1 ))
30803081 {
3081- local_A[0 ] = A[idx];
3082+ if (loaded_values == 0 )
3083+ {
3084+ local_A[0 ] = A[idx];
3085+ local_A[1 ] = A[idx+blockDim .x -32 ];
30823086
3083- #pragma unroll 32
3084- for (int col = 0 ; col < 32 ; col++)
3085- local_B[col] = B[(col_offset+col)*ldb+idx];
3087+ #pragma unroll 32
3088+ for (int col = 0 ; col < 32 ; col++)
3089+ {
3090+ local_B[col] = B[(col_offset+col)*ldb+idx];
3091+ local_B[col+32 ] = B[(col_offset+col)*ldb+idx+blockDim .x -32 ];
3092+ }
3093+ loaded_values = 1 ;
3094+ }
3095+ else
3096+ {
3097+ local_A[0 ] = local_A[1 ];
3098+ loaded_values--;
3099+
3100+ #pragma unroll 32
3101+ for (int col = 0 ; col < 32 ; col++)
3102+ local_B[col] = local_B[col+32 ];
3103+ }
30863104
30873105 smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0 ];
30883106
@@ -3113,11 +3131,35 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
31133131 __syncthreads ();
31143132 if (idx < K && warp_id < (WARPS-1 ))
31153133 {
3116- local_A[0 ] = A[idx];
3134+ // local_A[0] = A[idx];
31173135
3118- #pragma unroll 32
3119- for (int col = 0 ; col < 32 ; col++)
3120- local_B[col] = B[(col_offset+col)*ldb+idx];
3136+ // #pragma unroll 32
3137+ // for(int col = 0; col < 32; col++)
3138+ // local_B[col] = B[(col_offset+col)*ldb+idx];
3139+ if (loaded_values == 0 )
3140+ {
3141+ local_A[0 ] = A[idx];
3142+ local_A[1 ] = A[idx+blockDim .x -32 ];
3143+
3144+ #pragma unroll 32
3145+ for (int col = 0 ; col < 32 ; col++)
3146+ {
3147+ local_B[col] = B[(col_offset+col)*ldb+idx];
3148+ local_B[col+32 ] = B[(col_offset+col)*ldb+idx+blockDim .x -32 ];
3149+ }
3150+ loaded_values = 1 ;
3151+ }
3152+ else
3153+ {
3154+ local_A[0 ] = local_A[1 ];
3155+ loaded_values--;
3156+
3157+ #pragma unroll 32
3158+ for (int col = 0 ; col < 32 ; col++)
3159+ local_B[col] = local_B[col+32 ];
3160+
3161+
3162+ }
31213163
31223164 smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0 ];
31233165
0 commit comments