Skip to content

Commit 17ae0fa

Browse files
author
Haojin Yang
committed
Update sha, simplified grad calculation, adapted mbwq linear cuda kernel.
1 parent 65df40d commit 17ae0fa

File tree

4 files changed

+7
-26
lines changed

4 files changed

+7
-26
lines changed

bitorch_engine/layers/qlinear/nbit/cuda/mbwq_layer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,8 @@ def backward(ctx: torch.autograd.function.BackwardCFunction,
109109
grad_input = output_gradient.mm(weights.t()) # (m, n)*(n, k) = (m, k)
110110
#======================================================================================================#
111111

112-
# (n, m) * (m, k) = (n, k)
113112
if qweight.requires_grad: # This additional check is required by peft training.
114-
qweight.privileged_grad = output_gradient.t().mm(input).t() # (k, n)
113+
qweight.privileged_grad = input.t().mm(output_gradient) # (k, m) * (m, n) = (k, n)
115114

116115
grad_input = unflatten_x(grad_input, shape)
117116

bitorch_engine/layers/qlinear/nbit/cuda/mbwq_linear_cuda_kernel.cu

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -749,7 +749,6 @@ torch::Tensor mbwq_linear_q4_forward_cuda(
749749
int bits
750750
){
751751
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
752-
cublasHandle_t cublas_handle = at::cuda::getCurrentCUDABlasHandle();
753752

754753
TORCH_CHECK(x.dtype() == torch::kHalf);
755754
TORCH_CHECK(x.size(1) == qweight.size(0) * (32 / bits));
@@ -770,16 +769,8 @@ torch::Tensor mbwq_linear_q4_forward_cuda(
770769
group_size,
771770
bits,
772771
q_perm);
773-
774-
const half alpha = __float2half(1.0f);
775-
const half beta = __float2half(0.0f);
776-
cublasHgemm(cublas_handle,
777-
CUBLAS_OP_N,
778-
CUBLAS_OP_N,
779-
size_n, size_m, size_k,
780-
&alpha, reinterpret_cast<half *>(fp_w.data_ptr()), size_n,
781-
reinterpret_cast<half *>(x.data_ptr()), size_k,
782-
&beta, reinterpret_cast<half *>(out.data_ptr()), size_n);
772+
// indirectly use cublas through torch matmul api
773+
out = torch::matmul(x, fp_w.to(option_output));
783774

784775
}else{
785776

@@ -943,7 +934,6 @@ torch::Tensor mbwq_linear_exl2_forward_cuda(
943934
bool use_cublas
944935
){
945936
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
946-
cublasHandle_t cublas_handle = at::cuda::getCurrentCUDABlasHandle();
947937
TORCH_CHECK(x.dtype() == torch::kHalf);
948938

949939
int size_m = x.size(0); // m
@@ -963,15 +953,8 @@ torch::Tensor mbwq_linear_exl2_forward_cuda(
963953
qgroup_map,
964954
rows);
965955

966-
const half alpha = __float2half(1.0f);
967-
const half beta = __float2half(0.0f);
968-
cublasHgemm(cublas_handle,
969-
CUBLAS_OP_N,
970-
CUBLAS_OP_N,
971-
size_n, size_m, size_k,
972-
&alpha, reinterpret_cast<half *>(fp_w.data_ptr()), size_n,
973-
reinterpret_cast<half *>(x.data_ptr()), size_k,
974-
&beta, reinterpret_cast<half *>(out.data_ptr()), size_n);
956+
// indirectly use cublas through torch matmul api
957+
out = torch::matmul(x, fp_w.to(option_output));
975958

976959
}else{
977960
int rows_8 = rows[0];

bitorch_engine/layers/qlinear/nbit/cuda/mpq_layer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,8 @@ def backward(ctx: torch.autograd.function.BackwardCFunction,
100100
output_gradient, a_bit, w_bit, asym)
101101
#==================================================================#
102102

103-
# (n, m) * (m, k) = (n, k)
104103
if qweight.requires_grad: # This additional check is required by peft training.
105-
qweight.privileged_grad = output_gradient.t().mm(input).t() # (k, n)
104+
qweight.privileged_grad = input.t().mm(output_gradient) # (k, m) * (m, n) = (k, n)
106105

107106
grad_input = unflatten_x(grad_input, shape)
108107

docker/build_scripts/install_modified_pytorch.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ fi
2424
if [ "${from_image}" == "pytorch/pytorch:2.2.0-cuda12.1-cudnn8-devel" ]; then
2525
gdrive_id="1LjFNImboq8QeFSompMS2gPjBRYtP2Dsz"
2626
file="torch-2.2.2-cp310-cp310-linux_x86_64.whl"
27-
checksum="2a5953dab7be6c1640112e38ae7519ad88180d9fa79faab6c86dbee6b1cc210e"
27+
checksum="bcc0ba7f121ee2f42ed0a59f01d4e3d70f82a8981be0be25c5e0fe0635a54b2d"
2828
fi
2929
#if [ "${from_image}" == "pytorch/pytorch:X.X.X-cudaXX.X-cudnn8-devel" ]; then
3030
# gdrive_id="xxx"

0 commit comments

Comments (0)