@@ -749,6 +749,8 @@ torch::Tensor mbwq_linear_q4_forward_cuda(
     int bits
 ){
     const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+    cublasHandle_t cublas_handle = at::cuda::getCurrentCUDABlasHandle();
+
     TORCH_CHECK(x.dtype() == torch::kHalf);
     TORCH_CHECK(x.size(1) == qweight.size(0) * (32 / bits));
 
@@ -760,50 +762,72 @@ torch::Tensor mbwq_linear_q4_forward_cuda(
     auto option_output = torch::TensorOptions().dtype(x.dtype()).device(x.device());
     auto out = torch::zeros({size_m, size_n}, option_output);
 
-    bool is_q_perm_all_zeros = torch::all(q_perm == 0).item<bool>();
-    auto perm_value = is_q_perm_all_zeros ? nullptr : reinterpret_cast<uint16_t*>(q_perm.data_ptr());
+    if (size_m > MAX_Q_GEMM_ROWS) {
+        // Reconstruct the FP16 weight matrix and use cuBLAS for the GEMM
+        auto fp_w = mbwq_linear_q42fp_weight_cuda(qweight,
+                                                  scales,
+                                                  zeros,
+                                                  group_size,
+                                                  bits,
+                                                  q_perm);
 
-    dim3 blockDim, gridDim;
-    blockDim.x = GPTQ_BLOCK_KN_SIZE;
-    blockDim.y = 1;
-    blockDim.z = 1;
-    gridDim.x = DIVIDE(size_n, GPTQ_BLOCK_KN_SIZE * 4);
-    gridDim.y = DIVIDE(size_m, GPTQ_BLOCK_M_SIZE_MAX);
-    gridDim.z = DIVIDE(size_k, GPTQ_BLOCK_KN_SIZE);
+        const half alpha = __float2half(1.0f);
+        const half beta = __float2half(0.0f);
+        cublasHgemm(cublas_handle,
+                    CUBLAS_OP_N,
+                    CUBLAS_OP_N,
+                    size_n, size_m, size_k,
+                    &alpha, reinterpret_cast<half*>(fp_w.data_ptr()), size_n,
+                    reinterpret_cast<half*>(x.data_ptr()), size_k,
+                    &beta, reinterpret_cast<half*>(out.data_ptr()), size_n);
 
-    if (bits == 4) {
-        gemm_half_q4_half_gptq_kernel<GPTQ_BLOCK_M_SIZE_MAX><<<gridDim, blockDim>>>(
-            reinterpret_cast<half*>(x.data_ptr()),
-            reinterpret_cast<uint32_t*>(qweight.data_ptr()),
-            reinterpret_cast<half*>(zeros.data_ptr()),
-            reinterpret_cast<half*>(scales.data_ptr()),
-            reinterpret_cast<half*>(out.data_ptr()),
-            size_m,
-            size_n,
-            size_k,
-            groups,
-            group_size,
-            true,
-            perm_value
-        );
-    } else if (bits == 2) {
-        gemm_half_q2_half_gptq_kernel<GPTQ_BLOCK_M_SIZE_MAX><<<gridDim, blockDim>>>(
-            reinterpret_cast<half*>(x.data_ptr()),
-            reinterpret_cast<uint32_t*>(qweight.data_ptr()),
-            reinterpret_cast<half*>(zeros.data_ptr()),
-            reinterpret_cast<half*>(scales.data_ptr()),
-            reinterpret_cast<half*>(out.data_ptr()),
-            size_m,
-            size_n,
-            size_k,
-            groups,
-            group_size,
-            true,
-            perm_value
-        );
-    } else {
-        std::cerr << " Error: weight bit width:" << bits << " has not been supported yet!" << std::endl;
-        exit(EXIT_FAILURE);
+    } else {
+
+        bool is_q_perm_all_zeros = torch::all(q_perm == 0).item<bool>();
+        auto perm_value = is_q_perm_all_zeros ? nullptr : reinterpret_cast<uint16_t*>(q_perm.data_ptr());
+
+        dim3 blockDim, gridDim;
+        blockDim.x = GPTQ_BLOCK_KN_SIZE;
+        blockDim.y = 1;
+        blockDim.z = 1;
+        gridDim.x = DIVIDE(size_n, GPTQ_BLOCK_KN_SIZE * 4);
+        gridDim.y = DIVIDE(size_m, GPTQ_BLOCK_M_SIZE_MAX);
+        gridDim.z = DIVIDE(size_k, GPTQ_BLOCK_KN_SIZE);
+
+        if (bits == 4) {
+            gemm_half_q4_half_gptq_kernel<GPTQ_BLOCK_M_SIZE_MAX><<<gridDim, blockDim>>>(
+                reinterpret_cast<half*>(x.data_ptr()),
+                reinterpret_cast<uint32_t*>(qweight.data_ptr()),
+                reinterpret_cast<half*>(zeros.data_ptr()),
+                reinterpret_cast<half*>(scales.data_ptr()),
+                reinterpret_cast<half*>(out.data_ptr()),
+                size_m,
+                size_n,
+                size_k,
+                groups,
+                group_size,
+                true,
+                perm_value
+            );
+        } else if (bits == 2) {
+            gemm_half_q2_half_gptq_kernel<GPTQ_BLOCK_M_SIZE_MAX><<<gridDim, blockDim>>>(
+                reinterpret_cast<half*>(x.data_ptr()),
+                reinterpret_cast<uint32_t*>(qweight.data_ptr()),
+                reinterpret_cast<half*>(zeros.data_ptr()),
+                reinterpret_cast<half*>(scales.data_ptr()),
+                reinterpret_cast<half*>(out.data_ptr()),
+                size_m,
+                size_n,
+                size_k,
+                groups,
+                group_size,
+                true,
+                perm_value
+            );
+        } else {
+            std::cerr << " Error: weight bit width:" << bits << " has not been supported yet!" << std::endl;
+            exit(EXIT_FAILURE);
+        }
     }
 
     return out;
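The `cublasHgemm` call in the new large-batch path relies on the usual trick for driving column-major cuBLAS with row-major PyTorch tensors: a row-major `(m, k)` buffer is bit-identical to a column-major `(k, m)` matrix, so `out = x @ W` is expressed as `(x @ W)^T = W^T @ x^T` by passing the dequantized weight first and swapping `size_m`/`size_n`, which leaves the result already laid out row-major in `out`. A minimal standalone sketch of that convention (sizes, values, and variable names are illustrative, not taken from this commit):

```cpp
// Row-major C = A * B via column-major cublasHgemm, by swapping the operands.
// Build with: nvcc rowmajor_hgemm.cu -lcublas   (error checking omitted for brevity)
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

int main() {
    const int m = 2, k = 3, n = 4;                       // C is (m x n), row-major
    std::vector<half> hA(m * k, __float2half(1.0f));     // A: all ones, (m x k)
    std::vector<half> hB(k * n, __float2half(2.0f));     // B: all twos, (k x n)
    std::vector<half> hC(m * n);

    half *dA, *dB, *dC;
    cudaMalloc((void**)&dA, sizeof(half) * m * k);
    cudaMalloc((void**)&dB, sizeof(half) * k * n);
    cudaMalloc((void**)&dC, sizeof(half) * m * n);
    cudaMemcpy(dA, hA.data(), sizeof(half) * m * k, cudaMemcpyHostToDevice);
    cudaMemcpy(dB, hB.data(), sizeof(half) * k * n, cudaMemcpyHostToDevice);

    cublasHandle_t handle;
    cublasCreate(&handle);
    const half alpha = __float2half(1.0f);
    const half beta  = __float2half(0.0f);

    // cuBLAS sees each row-major buffer as its transpose, so we request
    // C^T (n x m) = B^T (n x k) * A^T (k x m); the column-major result is
    // exactly the row-major C = A * B we want, with leading dimension n.
    cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                n, m, k,
                &alpha, dB, n,
                        dA, k,
                &beta,  dC, n);

    cudaMemcpy(hC.data(), dC, sizeof(half) * m * n, cudaMemcpyDeviceToHost);
    printf("C[0][0] = %.1f (expected %.1f)\n", __half2float(hC[0]), 2.0f * k);

    cublasDestroy(handle);
    cudaFree(dA); cudaFree(dB); cudaFree(dC);
    return 0;
}
```

This is also presumably why the `size_m > MAX_Q_GEMM_ROWS` branch only pays for the dequantization into `fp_w` plus one HGEMM: no explicit transpose kernel is needed on either side of the call.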
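Separately, the shape check `TORCH_CHECK(x.size(1) == qweight.size(0) * (32 / bits))` near the top of the hunk encodes the packing convention: each 32-bit word of `qweight` covers `32 / bits` input features (eight values at 4 bits, sixteen at 2 bits). A small sketch of that packing for the 4-bit case; the helper names below are made up for illustration and are not part of the repository, and the real kernels may order the nibbles differently — only the `32 / bits` count matters for the shape check:

```cpp
// Illustrative 4-bit packing: eight quantized values per uint32_t word,
// lowest index in the least significant nibble. With this layout a weight
// column of K features needs only K * 4 / 32 = K / 8 packed words.
#include <cstdint>
#include <cstdio>

// Hypothetical helper: pack eight 4-bit values (0..15) into one word.
uint32_t pack_q4(const uint8_t vals[8]) {
    uint32_t word = 0;
    for (int i = 0; i < 8; ++i)
        word |= static_cast<uint32_t>(vals[i] & 0xF) << (4 * i);
    return word;
}

// Hypothetical helper: unpack the i-th 4-bit value from a packed word.
uint8_t unpack_q4(uint32_t word, int i) {
    return (word >> (4 * i)) & 0xF;
}

int main() {
    const uint8_t vals[8] = {1, 2, 3, 4, 5, 6, 7, 15};
    const uint32_t word = pack_q4(vals);
    for (int i = 0; i < 8; ++i)
        printf("%d ", unpack_q4(word, i));   // prints: 1 2 3 4 5 6 7 15
    printf("\n");
    return 0;
}
```

The actual dequantization kernels additionally apply the per-group `scales` and `zeros`; the sketch only shows where the `32 / bits` packing factor in the shape check comes from.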