You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: csrc/ops.cu
-25Lines changed: 0 additions & 25 deletions
Original file line number
Diff line number
Diff line change
@@ -674,43 +674,18 @@ template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out
674
674
675
675
int num_blocks = (m+31)/32;
676
676
677
-
//cout << num_blocks << endl;
678
-
//cout << lda << endl;
679
-
//cout << ldb << endl;
680
-
//cout << ldc << endl;
681
-
682
-
//cout << m << endl;
683
-
//cout << n << endl;
684
-
//cout << k << endl;
685
677
if(bits == 32)
686
-
//gemm_device<T, 32, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
687
678
gemm_device<T, 32, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
688
679
if(bits == 16)
689
-
//gemm_device<T, 16, 256><<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
690
680
gemm_device<T, 16, 160><<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
691
-
//gemm_device<T, 16, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
692
-
//gemm_device<T, 16, 96><<< num_blocks, 96, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
693
-
//gemm_device<T, 16, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
694
-
//gemm_device<T, 16, 64><<< num_blocks, 64, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
695
681
}
696
682
697
683
template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
698
684
{
699
685
700
686
int num_blocks = (m+31)/32;
701
687
702
-
//cout << num_blocks << endl;
703
-
//cout << lda << endl;
704
-
//cout << ldb << endl;
705
-
//cout << ldc << endl;
706
-
707
-
//cout << m << endl;
708
-
//cout << n << endl;
709
-
//cout << k << endl;
710
688
kgemm_4bit_inference<T, 96><<< num_blocks, 96, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
711
-
//kgemm_4bit_inference<T, 256><<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
712
-
//kgemm_4bit_inference<T, 160><<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
713
-
//kgemm_4bit_inference<T, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
714
689
}
715
690
716
691
template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream)
0 commit comments