
Commit 77f15fd

Shared memory efficient 240.
1 parent 89cccd8 commit 77f15fd

File tree

csrc/kernels.cu
csrc/ops.cu
tests/test_functional.py

3 files changed: +22 −64 lines changed

csrc/kernels.cu

Lines changed: 19 additions & 61 deletions
@@ -3041,7 +3041,7 @@ template <typename T, typename TCAST, int ITEMS> __device__ inline void vector_l
   }
 }
 
-#define WARPS 6
+#define WARPS 5
 template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
 {
 
@@ -3061,23 +3061,18 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
   T local_A[1];
   T local_B[32];
 
-  const int a_tile_offset = (16 + 16);
+  const int a_tile_offset = 16;
   const int b_tile_offset = (16*32 + 16);
 
-  __shared__ T smem_A[8*16 + (4*16*(batch_size_warps-1))];
+  __shared__ T smem_A[8*16 + (2*16*(batch_size_warps-1))];
   __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))];
-  __shared__ T smem_C[8*32];
+  //__shared__ T smem_C[8*32];
 
   wmma::fragment<wmma::matrix_a, 8, 32, 16, half, wmma::row_major> a_frag;
   wmma::fragment<wmma::matrix_b, 8, 32, 16, half, wmma::col_major> b_frag;
   wmma::fragment<wmma::accumulator, 8, 32, 16, half> c_frag;
-
   wmma::fill_fragment(c_frag, 0.0f);
 
-  for(int i = threadIdx.x; i < 8*32; i+=blockDim.x)
-    smem_C[i] = T(0);
-  __syncthreads();
-
   int ticktock = 0;
   int idx = 0 + threadIdx.x;
   // prefetch
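Rough shared-memory arithmetic behind this hunk, assuming T = half (2 bytes) and that batch_size_warps is defined as WARPS - 1 (its definition sits outside this hunk, so treat these totals as an estimate):

// estimated static shared memory per block, assuming batch_size_warps = WARPS - 1
// old (WARPS = 6, batch_size_warps = 5):
//   smem_A: 8*16 + 4*16*4      =  384 halfs =   768 B
//   smem_B: 2*5*16*32 + 2*16*4 = 5248 halfs = 10496 B
//   smem_C: 8*32               =  256 halfs =   512 B   -> 11776 B total
// new (WARPS = 5, batch_size_warps = 4):
//   smem_A: 8*16 + 2*16*3      =  224 halfs =   448 B
//   smem_B: 2*4*16*32 + 2*16*3 = 4192 halfs =  8384 B
//   smem_C: eliminated                                  ->  8832 B total

Removing smem_C also removes its zero-fill loop and the accompanying __syncthreads(); the accumulator itself still starts from wmma::fill_fragment(c_frag, 0.0f).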
@@ -3155,63 +3150,24 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
   }
 
   __syncthreads();
+  if(warp_id != (WARPS-1)){ return; }
+  // only warp_id == (WARPS-1) from here
+  int warp_lane = threadIdx.x % 32;
+
   ticktock = ticktock == 0 ? 1 : 0;
-  if(warp_id == (WARPS-1))
-    for(int k = 0; k < batch_size_warps; k++)
-    {
-      wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu
-      wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
-      wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
-    }
-  __syncthreads();
+  for(int k = 0; k < batch_size_warps; k++)
+  {
+    wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu
+    wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
+    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
+  }
 
   // 129 mu
   if(warp_id == (WARPS-1))
-    wmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major);
-  __syncthreads();
-
+    wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major);
 
-  //if(threadIdx.x >= 16){ return; }
-  //printf("%i %f\n", threadIdx.x, (float)smem_C[threadIdx.x]);
-
-  //if(threadIdx.x < 32)
-  //if(half_warp_lane < 8 && half_warp_id > 0)
-  //  //local_C[warp_lane] = smem_C[warp_lane + (warp_id*32*8)];
-  //  atomicAdd(&(smem_C[half_warp_lane]), smem_C[half_warp_lane + (half_warp_id*c_tile_offset)]);
-  //__syncthreads();
-
-  //local_accC[row] = BlockReduce(temp_storage.reduce).Reduce(local_accC[row], cub::Sum());
-  //if(threadIdx.x == 0)
-  //  for(int row = 0; row < 32; row++)
-  //  {
-  //    printf("row %i ", row);
-  //    for(int id = 0; id < 4; id++)
-  //    {
-  //      printf(" id %i: ", id);
-  //      for(int k = 0; k < 8; k++)
-  //        printf("%f ", (float)smem_C[k + (row*8) + (id*32*8)]);
-  //      printf("\n");
-  //    }
-  //  }
-
-  //__syncthreads();
-
-  //if((float)local_C[0] !=0.0f)
-  //  printf("%i %i %f\n", warp_lane, warp_id, (float)local_C[0]);
-  //local_C[0] = WarpReduce(temp_storage).Sum(local_C[0]);
-
-  //__syncwarp();
-
-  ////for(int i = threadIdx.x; i < 32*8; i+=blockDim.x)
-  ////{
-  //  if((float)local_C[0] !=0.0f)
-  //    printf("%i %f\n", 0, (float)local_C[0]);
-  //}
-
-  //if(threadIdx.x < 8 && col_offset + threadIdx.x < M)
-  //out[col_offset + threadIdx.x ] = smem_C[threadIdx.x];
-  if(threadIdx.x < 32 && col_offset + threadIdx.x < M)
-    out[col_offset + threadIdx.x] = smem_C[threadIdx.x];
+  if(col_offset + warp_lane < M)
+    out[col_offset + warp_lane] = smem_A[warp_lane];
 }
 
 template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
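Two structural changes drive the epilogue above: every warp except the last returns right after the final __syncthreads() (so no block-wide barrier may appear below that point), and the 8x32 result tile is staged through smem_A, whose input data is dead by then, instead of through a dedicated smem_C. A minimal standalone sketch of that reuse pattern, with illustrative names rather than the library's actual code:

#include <cuda_fp16.h>

// Sketch: once every warp has passed the last __syncthreads(), a scratch
// buffer whose contents are no longer read may double as output staging
// for the single warp that writes the result.
template <int WARPS>
__global__ void epilogue_reuse_sketch(half *out, int M, int col_offset)
{
  __shared__ half smem_A[8*16];              // stand-in for the real tile buffer
  const int warp_id   = threadIdx.x / 32;
  const int warp_lane = threadIdx.x % 32;

  // ... main loop: all warps stream input tiles through smem_A ...
  __syncthreads();                           // last read of smem_A's old contents
  if(warp_id != (WARPS-1)){ return; }        // no __syncthreads() allowed below this

  // smem_A is dead storage now; the surviving warp reuses it for its result
  smem_A[warp_lane] = __float2half(1.0f);    // placeholder for wmma::store_matrix_sync
  if(col_offset + warp_lane < M)
    out[col_offset + warp_lane] = smem_A[warp_lane];
}

Because only one warp survives past the early return, the store into smem_A and the reads right after it stay within that warp, which is why the old trailing __syncthreads() calls could be dropped along with smem_C.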
@@ -3496,6 +3452,7 @@ __global__ void with_staging_unified(float const* global_in, float * global_out,
 //template __global__ void gemm_device<float, 16, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
 template __global__ void gemm_device<half, 32, 256>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
 template __global__ void gemm_device<half, 32, 192>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
+template __global__ void gemm_device<half, 32, 160>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
 template __global__ void gemm_device<half, 32, 128>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
 //template __global__ void gemm_device<float, 16, 32>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
 template __global__ void gemm_device<half, 32, 32>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
@@ -3506,6 +3463,7 @@ template __global__ void gemm_device<half, 32, 96>(int M, int N, int K, half * _
 //template __global__ void gemm_device<float, 32, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
 template __global__ void gemm_device<half, 16, 256>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
 template __global__ void gemm_device<half, 16, 192>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
+template __global__ void gemm_device<half, 16, 160>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
 template __global__ void gemm_device<half, 16, 128>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
 //template __global__ void gemm_device<float, 32, 32>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
 template __global__ void gemm_device<half, 16, 32>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
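The added <half, 32, 160> and <half, 16, 160> lines are needed because gemm_device is defined in this file but launched from ops.cu: with explicit instantiation, every (T, BITS, THREADS) combination used at a launch site must have a matching template line in this translation unit. A small sketch of the mechanism with a hypothetical kernel:

#include <cuda_fp16.h>

// Definition lives in one .cu file; other files see only a declaration,
// so each THREADS value used at a launch must be instantiated here.
template <typename T, int THREADS>
__global__ void toy_kernel(T *x, int n)
{
  int i = blockIdx.x * THREADS + threadIdx.x;  // contract: blockDim.x == THREADS
  if(i < n) x[i] = __hadd(x[i], x[i]);         // toy body: double each element
}

// analogous to the gemm_device<half, 16, 160> instantiation above
template __global__ void toy_kernel<half, 160>(half*, int);

A caller in another file would then launch toy_kernel<half, 160><<<num_blocks, 160>>>(x, n); 160 threads per block is 5 warps of 32, matching the new WARPS 5.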

csrc/ops.cu

Lines changed: 1 addition & 1 deletion
@@ -693,7 +693,7 @@ template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out
   //gemm_device<T, 32, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
   if(bits == 16)
     //gemm_device<T, 16, 256><<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
-    gemm_device<T, 16, 192><<< num_blocks, 192, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
+    gemm_device<T, 16, 160><<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
   //gemm_device<T, 16, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
   //gemm_device<T, 16, 96><<< num_blocks, 96, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
   //gemm_device<T, 16, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
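The host launch now passes 160 both as the template argument and as the block size, keeping blockDim.x consistent with THREADS. To gauge what the smaller block and shared-memory footprint buy, one could query occupancy with the standard runtime API; a hedged sketch, not part of this commit:

#include <cuda_runtime.h>

// How many 160-thread blocks of the new variant fit on one SM, given the
// kernel's static shared memory and register usage (gemm_device as declared above).
int blocks_per_sm_for_gemm_160()
{
  int max_blocks = 0;
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(
      &max_blocks, gemm_device<half, 16, 160>, 160, /*dynamicSmemBytes=*/0);
  return max_blocks;
}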

tests/test_functional.py

Lines changed: 2 additions & 2 deletions
@@ -2358,9 +2358,9 @@ def test_normal_map_tree():
 #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
 @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
 def test_cutlass3_gemm(dtype):
-    #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
+    for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
     #for dim in [4096, 5120, 6656, 8192]:
-    for dim in [4096]:
+    #for dim in [4096]:
         errs = []
         relerrs = []
         max_err = 0
