Commit 65df40d

Haojin Yang authored and committed
Improved the performance of mbwq_linear_q4_forward_cuda for long seq-length.
1 parent b5bfedf commit 65df40d

File tree

3 files changed: +74 / -44 lines changed


CHANGELOG.md

Lines changed: 6 additions & 0 deletions
@@ -5,6 +5,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/)
 and this project adheres to [Semantic Versioning](http://semver.org/).
 
 
+## [0.2.3] - 2024/05/01
+
+### Updated
+
+- Enhanced the performance of the MBWQ linear layer for processing long sequences, addressing previous inefficiencies.
+
 ## [0.2.2] - 2024/04/29
 
 ### Updated

bitorch_engine/layers/qlinear/nbit/cuda/mbwq_linear_cuda_kernel.cu

Lines changed: 66 additions & 42 deletions
@@ -749,6 +749,8 @@ torch::Tensor mbwq_linear_q4_forward_cuda(
     int bits
 ){
     const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+    cublasHandle_t cublas_handle = at::cuda::getCurrentCUDABlasHandle();
+
     TORCH_CHECK(x.dtype() == torch::kHalf);
     TORCH_CHECK(x.size(1) == qweight.size(0) * (32 / bits));
 
@@ -760,50 +762,72 @@ torch::Tensor mbwq_linear_q4_forward_cuda(
     auto option_output = torch::TensorOptions().dtype(x.dtype()).device(x.device());
     auto out = torch::zeros({size_m, size_n}, option_output);
 
-    bool is_q_perm_all_zeros = torch::all(q_perm == 0).item<bool>();
-    auto perm_value = is_q_perm_all_zeros ? nullptr : reinterpret_cast<uint16_t *>(q_perm.data_ptr());
+    if (size_m > MAX_Q_GEMM_ROWS){
+        // Reconstruct FP16 matrix and using cuBLAS for gemm
+        auto fp_w = mbwq_linear_q42fp_weight_cuda(qweight,
+                                                  scales,
+                                                  zeros,
+                                                  group_size,
+                                                  bits,
+                                                  q_perm);
 
-    dim3 blockDim, gridDim;
-    blockDim.x = GPTQ_BLOCK_KN_SIZE;
-    blockDim.y = 1;
-    blockDim.z = 1;
-    gridDim.x = DIVIDE(size_n, GPTQ_BLOCK_KN_SIZE * 4);
-    gridDim.y = DIVIDE(size_m, GPTQ_BLOCK_M_SIZE_MAX);
-    gridDim.z = DIVIDE(size_k, GPTQ_BLOCK_KN_SIZE);
+        const half alpha = __float2half(1.0f);
+        const half beta = __float2half(0.0f);
+        cublasHgemm(cublas_handle,
+                    CUBLAS_OP_N,
+                    CUBLAS_OP_N,
+                    size_n, size_m, size_k,
+                    &alpha, reinterpret_cast<half *>(fp_w.data_ptr()), size_n,
+                    reinterpret_cast<half *>(x.data_ptr()), size_k,
+                    &beta, reinterpret_cast<half *>(out.data_ptr()), size_n);
 
-    if (bits == 4){
-        gemm_half_q4_half_gptq_kernel<GPTQ_BLOCK_M_SIZE_MAX><<<gridDim, blockDim>>>(
-            reinterpret_cast<half *>(x.data_ptr()),
-            reinterpret_cast<uint32_t *>(qweight.data_ptr()),
-            reinterpret_cast<half *>(zeros.data_ptr()),
-            reinterpret_cast<half *>(scales.data_ptr()),
-            reinterpret_cast<half *>(out.data_ptr()),
-            size_m,
-            size_n,
-            size_k,
-            groups,
-            group_size,
-            true,
-            perm_value
-        );
-    } else if (bits == 2){
-        gemm_half_q2_half_gptq_kernel<GPTQ_BLOCK_M_SIZE_MAX><<<gridDim, blockDim>>>(
-            reinterpret_cast<half *>(x.data_ptr()),
-            reinterpret_cast<uint32_t *>(qweight.data_ptr()),
-            reinterpret_cast<half *>(zeros.data_ptr()),
-            reinterpret_cast<half *>(scales.data_ptr()),
-            reinterpret_cast<half *>(out.data_ptr()),
-            size_m,
-            size_n,
-            size_k,
-            groups,
-            group_size,
-            true,
-            perm_value
-        );
-    } else {
-        std::cerr << "Error: weight bit width:"<< bits <<" has not been supported yet!" << std::endl;
-        exit(EXIT_FAILURE);
+    }else{
+
+        bool is_q_perm_all_zeros = torch::all(q_perm == 0).item<bool>();
+        auto perm_value = is_q_perm_all_zeros ? nullptr : reinterpret_cast<uint16_t *>(q_perm.data_ptr());
+
+        dim3 blockDim, gridDim;
+        blockDim.x = GPTQ_BLOCK_KN_SIZE;
+        blockDim.y = 1;
+        blockDim.z = 1;
+        gridDim.x = DIVIDE(size_n, GPTQ_BLOCK_KN_SIZE * 4);
+        gridDim.y = DIVIDE(size_m, GPTQ_BLOCK_M_SIZE_MAX);
+        gridDim.z = DIVIDE(size_k, GPTQ_BLOCK_KN_SIZE);
+
+        if (bits == 4){
+            gemm_half_q4_half_gptq_kernel<GPTQ_BLOCK_M_SIZE_MAX><<<gridDim, blockDim>>>(
+                reinterpret_cast<half *>(x.data_ptr()),
+                reinterpret_cast<uint32_t *>(qweight.data_ptr()),
+                reinterpret_cast<half *>(zeros.data_ptr()),
+                reinterpret_cast<half *>(scales.data_ptr()),
+                reinterpret_cast<half *>(out.data_ptr()),
+                size_m,
+                size_n,
+                size_k,
+                groups,
+                group_size,
+                true,
+                perm_value
+            );
+        } else if (bits == 2){
+            gemm_half_q2_half_gptq_kernel<GPTQ_BLOCK_M_SIZE_MAX><<<gridDim, blockDim>>>(
+                reinterpret_cast<half *>(x.data_ptr()),
+                reinterpret_cast<uint32_t *>(qweight.data_ptr()),
+                reinterpret_cast<half *>(zeros.data_ptr()),
+                reinterpret_cast<half *>(scales.data_ptr()),
+                reinterpret_cast<half *>(out.data_ptr()),
+                size_m,
+                size_n,
+                size_k,
+                groups,
+                group_size,
+                true,
+                perm_value
+            );
+        } else {
+            std::cerr << "Error: weight bit width:"<< bits <<" has not been supported yet!" << std::endl;
+            exit(EXIT_FAILURE);
+        }
     }
 
     return out;
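
Note on the new cuBLAS path: when size_m (rows of x, i.e. batch times sequence length) exceeds MAX_Q_GEMM_ROWS, the forward pass now dequantizes the packed weights to FP16 via mbwq_linear_q42fp_weight_cuda and delegates the matrix multiply to cublasHgemm instead of the custom GPTQ kernels. The call passes the weight as the first operand and the activations as the second with problem size (size_n, size_m, size_k); since cuBLAS assumes column-major storage, this computes out^T = w^T * x^T, which is exactly the row-major product x * w the layer needs. The following is a minimal, self-contained sketch of that convention, not code from this repository; the function name row_major_hgemm and the raw device pointers are illustrative assumptions.

// Sketch: row-major out[M x N] = x[M x K] * w[K x N] in FP16 through
// column-major cublasHgemm (illustrative only, not repository code).
#include <cublas_v2.h>
#include <cuda_fp16.h>

void row_major_hgemm(cublasHandle_t handle,
                     const half* x,   // activations, [M x K] row-major
                     const half* w,   // weights,     [K x N] row-major
                     half* out,       // output,      [M x N] row-major
                     int M, int N, int K) {
    const half alpha = __float2half(1.0f);
    const half beta  = __float2half(0.0f);
    // Passing w as A (lda = N) and x as B (ldb = K) with size (N, M, K)
    // makes cuBLAS compute out^T = w^T * x^T in its column-major view,
    // so the desired row-major x * w lands directly in out.
    cublasHgemm(handle,
                CUBLAS_OP_N, CUBLAS_OP_N,
                N, M, K,
                &alpha, w, N,
                x, K,
                &beta, out, N);
}

The returned cublasStatus_t is ignored here for brevity; production code would check it. The committed change keeps the existing GPTQ kernels for small size_m, where the cost of dequantizing the full weight matrix would likely outweigh the cuBLAS speedup.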

tests/layers/test_nbit_linear.py

Lines changed: 2 additions & 2 deletions
@@ -362,7 +362,7 @@ def test_mbwq_linear_q4_cuda(num_input_features, num_hidden_fc, batch_size, grou
                                             mbwq_linear_layer.zeros, group_size, mbwq_linear_layer.w_bit, mbwq_linear_layer.q_perm)
     result1_pt = torch.matmul(input_data_cuda, fp_weights)
 
-    assert torch.all(torch.isclose(result1, result1_pt, rtol=10, atol=10, equal_nan=False))
+    assert torch.mean(torch.abs(result1 - result1_pt)).item() < 2
 
     mpq_linear_layer = MPQLinearCuda(in_channels=num_input_features,
                                      out_channels=num_hidden_fc,
@@ -401,5 +401,5 @@ def test_mbwq_linear_q4_cuda(num_input_features, num_hidden_fc, batch_size, grou
     time_engine = time.time() - start_time
     print("bitorch-engine mpq_linear forward (CUDA) run time: %.6f s" % (time_engine/num_runs))
 
-    assert torch.all(torch.isclose(result1, result2, rtol=10, atol=10, equal_nan=False))
+    assert torch.mean(torch.abs(result1 - result2)).item() < 3
 
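
The updated assertions drop the element-wise torch.isclose comparison, whose rtol=10 / atol=10 tolerances were loose to the point of being nearly vacuous, in favor of a bound on the mean absolute error between the kernel output and the reference, which better reflects accumulated FP16 rounding over a large reduction dimension. A rough equivalent of that criterion, written with the libtorch C++ API to stay in the same language as the kernel code above (a sketch only; mean_abs_error_ok is an illustrative helper, not part of the test suite):

#include <torch/torch.h>

// Sketch of the updated acceptance criterion: two outputs match when their
// mean absolute error stays below a threshold (2 for the q4-vs-FP16
// reference check above, 3 for the engine-vs-engine check).
bool mean_abs_error_ok(const torch::Tensor& result,
                       const torch::Tensor& reference,
                       double threshold) {
    const double mae = torch::mean(torch::abs(result - reference)).item<double>();
    return mae < threshold;
}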
