You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
template <typename T, int BITS, int THREADS> __global__voidgemm_device(int M, int N, int K, T * __restrict__const A, T* B, T * out, int lda, int ldb, int ldc)
3046
3046
{
3047
3047
@@ -3056,17 +3056,18 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
3056
3056
constint warp_id = threadIdx.x / 32;
3057
3057
constint half_warp_id = threadIdx.x / 16;
3058
3058
constint half_warp_lane = threadIdx.x % 16;
3059
+
constint batch_size_warps = (WARPS-1)*2;
3059
3060
3060
3061
T local_A[1];
3061
3062
T local_B[8];
3062
3063
3063
-
constint a_tile_offset = 32*16 + 16;
3064
-
constint b_tile_offset = 16*8 + 16;
3064
+
constint a_tile_offset = (32*16 + 16);
3065
+
constint b_tile_offset = (16*8 + 16);
3065
3066
constint c_tile_offset = 32*8 + 24;
3066
3067
3067
-
__shared__ T smem_A[WARPS*32*16 + (16*(WARPS-1))];
3068
-
__shared__ T smem_B[WARPS*16*8 + (16*(WARPS-1))];
3069
-
__shared__ T smem_C[WARPS*32*8 + (24*(WARPS-1))];
3068
+
__shared__ T smem_A[2*batch_size_warps*32*16 + (2*16*(batch_size_warps-1))];
3069
+
__shared__ T smem_B[2*batch_size_warps*16*8 + (2*16*(batch_size_warps-1))];
// these are not used and make no sense, but the compiler needs them
3465
3471
//template __global__ void gemm_device<float, 16, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
3472
+
template __global__void gemm_device<half, 32, 256>(int M, int N, int K, half * __restrict__const A, half* B, half * out, int lda, int ldb, int ldc);
3466
3473
template __global__void gemm_device<half, 32, 128>(int M, int N, int K, half * __restrict__const A, half* B, half * out, int lda, int ldb, int ldc);
3467
3474
//template __global__ void gemm_device<float, 16, 32>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
3468
3475
template __global__void gemm_device<half, 32, 32>(int M, int N, int K, half * __restrict__const A, half* B, half * out, int lda, int ldb, int ldc);
3469
3476
template __global__void gemm_device<half, 32, 64>(int M, int N, int K, half * __restrict__const A, half* B, half * out, int lda, int ldb, int ldc);
3470
3477
// these are not used and make no sense, but the compiler needs them
3471
3478
3472
3479
//template __global__ void gemm_device<float, 32, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
3480
+
template __global__void gemm_device<half, 16, 256>(int M, int N, int K, half * __restrict__const A, half* B, half * out, int lda, int ldb, int ldc);
3473
3481
template __global__void gemm_device<half, 16, 128>(int M, int N, int K, half * __restrict__const A, half* B, half * out, int lda, int ldb, int ldc);
3474
3482
//template __global__ void gemm_device<float, 32, 32>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
3475
3483
template __global__void gemm_device<half, 16, 32>(int M, int N, int K, half * __restrict__const A, half* B, half * out, int lda, int ldb, int ldc);
Copy file name to clipboardExpand all lines: csrc/ops.cu
+3-2Lines changed: 3 additions & 2 deletions
Original file line number
Diff line number
Diff line change
@@ -692,9 +692,10 @@ template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out
692
692
//gemm_device<T, 32, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
693
693
//gemm_device<T, 32, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
694
694
if(bits == 16)
695
-
//gemm_device<T, 16, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
695
+
//gemm_device<T, 16, 256><<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
696
+
gemm_device<T, 16, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
696
697
//gemm_device<T, 16, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
697
-
gemm_device<T, 16, 64><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
698
+
//gemm_device<T, 16, 64><<< num_blocks, 64, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
698
699
}
699
700
700
701
template <typename T> voidgemm_4bit_inference(int m, int n, int k, T * A, unsignedchar* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
0 commit comments