You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
template <typename T, int BITS, int THREADS> __global__voidgemm_device(int M, int N, int K, T * __restrict__const A, T* B, T * out, int lda, int ldb, int ldc)
3046
3046
{
3047
3047
@@ -3052,26 +3052,26 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
// these are not used and make no sense, but the compiler needs them
3471
3471
//template __global__ void gemm_device<float, 16, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
3472
3472
template __global__void gemm_device<half, 32, 256>(int M, int N, int K, half * __restrict__const A, half* B, half * out, int lda, int ldb, int ldc);
3473
+
template __global__void gemm_device<half, 32, 192>(int M, int N, int K, half * __restrict__const A, half* B, half * out, int lda, int ldb, int ldc);
3473
3474
template __global__void gemm_device<half, 32, 128>(int M, int N, int K, half * __restrict__const A, half* B, half * out, int lda, int ldb, int ldc);
3474
3475
//template __global__ void gemm_device<float, 16, 32>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
3475
3476
template __global__void gemm_device<half, 32, 32>(int M, int N, int K, half * __restrict__const A, half* B, half * out, int lda, int ldb, int ldc);
3476
3477
template __global__void gemm_device<half, 32, 64>(int M, int N, int K, half * __restrict__const A, half* B, half * out, int lda, int ldb, int ldc);
3478
+
template __global__void gemm_device<half, 32, 96>(int M, int N, int K, half * __restrict__const A, half* B, half * out, int lda, int ldb, int ldc);
3477
3479
// these are not used and make no sense, but the compiler needs them
3478
3480
3479
3481
//template __global__ void gemm_device<float, 32, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
3480
3482
template __global__void gemm_device<half, 16, 256>(int M, int N, int K, half * __restrict__const A, half* B, half * out, int lda, int ldb, int ldc);
3483
+
template __global__void gemm_device<half, 16, 192>(int M, int N, int K, half * __restrict__const A, half* B, half * out, int lda, int ldb, int ldc);
3481
3484
template __global__void gemm_device<half, 16, 128>(int M, int N, int K, half * __restrict__const A, half* B, half * out, int lda, int ldb, int ldc);
3482
3485
//template __global__ void gemm_device<float, 32, 32>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
3483
3486
template __global__void gemm_device<half, 16, 32>(int M, int N, int K, half * __restrict__const A, half* B, half * out, int lda, int ldb, int ldc);
3484
3487
template __global__void gemm_device<half, 16, 64>(int M, int N, int K, half * __restrict__const A, half* B, half * out, int lda, int ldb, int ldc);
3488
+
template __global__void gemm_device<half, 16, 96>(int M, int N, int K, half * __restrict__const A, half* B, half * out, int lda, int ldb, int ldc);
3485
3489
3486
3490
template __global__void kgemm_4bit_inference<half, 128>(int M, int N, int K, half * __restrict__const A, unsignedchar *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
0 commit comments