You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
// these are not used and make no sense, but the compiler needs them
3590
3555
//template __global__ void gemm_device<float, 16, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
@@ -3611,9 +3576,6 @@ template __global__ void gemm_device<half, 16, 96>(int M, int N, int K, half * _
3611
3576
template __global__void kgemm_4bit_inference<half, 128>(int M, int N, int K, half * __restrict__const A, unsignedchar *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
3612
3577
template __global__void kgemm_4bit_inference<half, 160>(int M, int N, int K, half * __restrict__const A, unsignedchar *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
3613
3578
3614
-
3615
-
//template __global__ void kMatmul_inference_4bit<NF4, half, half, half>(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB);
Copy file name to clipboardExpand all lines: csrc/kernels.cuh
+2-16Lines changed: 2 additions & 16 deletions
Original file line number
Diff line number
Diff line change
@@ -122,23 +122,9 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
122
122
123
123
template <int FORMAT> __global__voidkExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
124
124
125
-
//template <class MShape, class NShape, class KShape,
126
-
// class TA, class AStride, class ABlockLayout, class AThreadLayout,
127
-
// class TB, class BStride, class BBlockLayout, class BThreadLayout,
128
-
// class TC, class CStride, class CBlockLayout, class CThreadLayout,
template <typename T, int BITS, int THREADS> __global__voidgemm_device(int M, int N, int K, T * __restrict__const A, T* B, T * out, int lda, int ldb, int ldc);
142
126
template <typename T, int THREADS> __global__voidkgemm_4bit_inference(int M, int N, int K, T * __restrict__const A, unsignedchar *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
143
127
128
+
template <typename T, int FUNC> __global__voidkfunc(T *A, T *B, T value, long n);
template void gemm_4bit_inference<half>(int m, int n, int k, half * A, unsignedchar* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
725
730
//template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
726
731
template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);
Copy file name to clipboardExpand all lines: csrc/ops.cuh
+8-1Lines changed: 8 additions & 1 deletion
Original file line number
Diff line number
Diff line change
@@ -93,6 +93,13 @@ typedef enum DataType_t
93
93
NF4 = 2,
94
94
} DataType_t;
95
95
96
+
typedefenum Funcs_t
97
+
{
98
+
FILL = 0,
99
+
ARANGE = 1,
100
+
_MUL = 2,
101
+
} Funcs_t;
102
+
96
103
classContext
97
104
{
98
105
public:
@@ -193,6 +200,6 @@ void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rows
193
200
template <typename T> voidgemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits);
194
201
template <typename T> voidgemm_4bit_inference(int m, int n, int k, T * A, unsignedchar* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
195
202
203
+
template <typename T, int FUNC> voidfunc(T *A, T *B, T value, long n);
196
204
197
-
voidpipeline_test(float *A, float *B, size_t n, size_t batch_size);
0 commit comments