You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
template <typename T, int BITS, int THREADS> __global__voidgemm_device(int M, int N, int K, T * __restrict__const A, T* B, T * out, int lda, int ldb, int ldc)
3046
3046
{
3047
3047
@@ -3061,23 +3061,18 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
3061
3061
T local_A[1];
3062
3062
T local_B[32];
3063
3063
3064
-
constint a_tile_offset = (16 + 16);
3064
+
constint a_tile_offset = 16;
3065
3065
constint b_tile_offset = (16*32 + 16);
3066
3066
3067
-
__shared__ T smem_A[8*16 + (4*16*(batch_size_warps-1))];
3067
+
__shared__ T smem_A[8*16 + (2*16*(batch_size_warps-1))];
3068
3068
__shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))];
template <typename T, int THREADS> __global__voidkgemm_4bit_inference(int M, int N, int K, T * __restrict__const A, unsignedchar *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
//template __global__ void gemm_device<float, 16, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
3497
3453
template __global__void gemm_device<half, 32, 256>(int M, int N, int K, half * __restrict__const A, half* B, half * out, int lda, int ldb, int ldc);
3498
3454
template __global__void gemm_device<half, 32, 192>(int M, int N, int K, half * __restrict__const A, half* B, half * out, int lda, int ldb, int ldc);
3455
+
template __global__void gemm_device<half, 32, 160>(int M, int N, int K, half * __restrict__const A, half* B, half * out, int lda, int ldb, int ldc);
3499
3456
template __global__void gemm_device<half, 32, 128>(int M, int N, int K, half * __restrict__const A, half* B, half * out, int lda, int ldb, int ldc);
3500
3457
//template __global__ void gemm_device<float, 16, 32>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
3501
3458
template __global__void gemm_device<half, 32, 32>(int M, int N, int K, half * __restrict__const A, half* B, half * out, int lda, int ldb, int ldc);
@@ -3506,6 +3463,7 @@ template __global__ void gemm_device<half, 32, 96>(int M, int N, int K, half * _
3506
3463
//template __global__ void gemm_device<float, 32, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
3507
3464
template __global__void gemm_device<half, 16, 256>(int M, int N, int K, half * __restrict__const A, half* B, half * out, int lda, int ldb, int ldc);
3508
3465
template __global__void gemm_device<half, 16, 192>(int M, int N, int K, half * __restrict__const A, half* B, half * out, int lda, int ldb, int ldc);
3466
+
template __global__void gemm_device<half, 16, 160>(int M, int N, int K, half * __restrict__const A, half* B, half * out, int lda, int ldb, int ldc);
3509
3467
template __global__void gemm_device<half, 16, 128>(int M, int N, int K, half * __restrict__const A, half* B, half * out, int lda, int ldb, int ldc);
3510
3468
//template __global__ void gemm_device<float, 32, 32>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
3511
3469
template __global__void gemm_device<half, 16, 32>(int M, int N, int K, half * __restrict__const A, half* B, half * out, int lda, int ldb, int ldc);
0 commit comments