
Commit 8addc3e

Merge pull request #2 from GreenBitAI/haojin_dev
Haojin dev
2 parents 65df40d + e88b670, commit 8addc3e

File tree: 10 files changed (+141, -47 lines)


CHANGELOG.md

Lines changed: 17 additions & 0 deletions

@@ -5,6 +5,23 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/)
 and this project adheres to [Semantic Versioning](http://semver.org/).
 
 
+## [0.2.4] - 2024/05/23
+
+### Added
+
+- Tuned the hyperparameters of DiodeMix optimizer for sft.
+- Added sft-support for the classical gptq-style models.
+- Implemented qzeros update in finetuning process.
+
+### Updated
+
+- Extended pack_fp_weight function.
+- Enhanced the performance of MPQLinearCUDA layer.
+
+### Fixed
+
+- Fixed various errors in DiodeMix update function.
+
 ## [0.2.3] - 2024/05/01
 
 ### Updated

bitorch_engine/layers/qlinear/nbit/cuda/mbwq_layer.py

Lines changed: 1 addition & 2 deletions

@@ -109,9 +109,8 @@ def backward(ctx: torch.autograd.function.BackwardCFunction,
         grad_input = output_gradient.mm(weights.t()) # (m, n)*(n, k) = (m, k)
         #======================================================================================================#
 
-        # (n, m) * (m, k) = (n, k)
         if qweight.requires_grad: # This additional check is required by peft training.
-            qweight.privileged_grad = output_gradient.t().mm(input).t() # (k, n)
+            qweight.privileged_grad = input.t().mm(output_gradient) # (k, m) * (m, n) = (k, n)
 
         grad_input = unflatten_x(grad_input, shape)
 

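Note on the change above: the new expression is algebraically identical to the old one, since (G^T X)^T = X^T G; it simply avoids two extra transposes. A minimal sketch verifying the equivalence on random tensors (the shapes are illustrative):

import torch

# Illustrative shapes: batch m, input features k, output features n.
m, k, n = 8, 16, 32
input = torch.randn(m, k)
output_gradient = torch.randn(m, n)

old_grad = output_gradient.t().mm(input).t()  # (n, m) @ (m, k) -> (n, k), transposed -> (k, n)
new_grad = input.t().mm(output_gradient)      # (k, m) @ (m, n) -> (k, n)

assert torch.allclose(old_grad, new_grad, atol=1e-5)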
bitorch_engine/layers/qlinear/nbit/cuda/mbwq_linear_cuda_kernel.cu

Lines changed: 4 additions & 21 deletions

@@ -749,7 +749,6 @@ torch::Tensor mbwq_linear_q4_forward_cuda(
     int bits
 ){
     const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
-    cublasHandle_t cublas_handle = at::cuda::getCurrentCUDABlasHandle();
 
     TORCH_CHECK(x.dtype() == torch::kHalf);
     TORCH_CHECK(x.size(1) == qweight.size(0) * (32 / bits));
@@ -770,16 +769,8 @@ torch::Tensor mbwq_linear_q4_forward_cuda(
             group_size,
             bits,
             q_perm);
-
-        const half alpha = __float2half(1.0f);
-        const half beta = __float2half(0.0f);
-        cublasHgemm(cublas_handle,
-                    CUBLAS_OP_N,
-                    CUBLAS_OP_N,
-                    size_n, size_m, size_k,
-                    &alpha, reinterpret_cast<half *>(fp_w.data_ptr()), size_n,
-                    reinterpret_cast<half *>(x.data_ptr()), size_k,
-                    &beta, reinterpret_cast<half *>(out.data_ptr()), size_n);
+        // indirectly use cublas through torch matmul api
+        out = torch::matmul(x, fp_w.to(option_output));
 
     }else{
 
@@ -943,7 +934,6 @@ torch::Tensor mbwq_linear_exl2_forward_cuda(
     bool use_cublas
 ){
     const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
-    cublasHandle_t cublas_handle = at::cuda::getCurrentCUDABlasHandle();
     TORCH_CHECK(x.dtype() == torch::kHalf);
 
     int size_m = x.size(0); // m
@@ -963,15 +953,8 @@ torch::Tensor mbwq_linear_exl2_forward_cuda(
             qgroup_map,
             rows);
 
-        const half alpha = __float2half(1.0f);
-        const half beta = __float2half(0.0f);
-        cublasHgemm(cublas_handle,
-                    CUBLAS_OP_N,
-                    CUBLAS_OP_N,
-                    size_n, size_m, size_k,
-                    &alpha, reinterpret_cast<half *>(fp_w.data_ptr()), size_n,
-                    reinterpret_cast<half *>(x.data_ptr()), size_k,
-                    &beta, reinterpret_cast<half *>(out.data_ptr()), size_n);
+        // indirectly use cublas through torch matmul api
+        out = torch::matmul(x, fp_w.to(option_output));
 
     }else{
         int rows_8 = rows[0];

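As the added comments say, torch::matmul still reaches cuBLAS for half-precision CUDA tensors; the change only drops the hand-written HGEMM call. A rough Python-level sketch of the shape contract the replacement relies on (sizes are illustrative and mirror the kernel's size_m/size_k/size_n):

import torch

size_m, size_k, size_n = 4, 64, 128  # illustrative sizes
x = torch.randn(size_m, size_k, dtype=torch.half, device="cuda")
fp_w = torch.randn(size_k, size_n, dtype=torch.half, device="cuda")

# Equivalent of the removed explicit HGEMM: out(m, n) = x(m, k) @ fp_w(k, n).
out = torch.matmul(x, fp_w)
assert out.shape == (size_m, size_n)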
bitorch_engine/layers/qlinear/nbit/cuda/mpq_layer.py

Lines changed: 21 additions & 7 deletions

@@ -6,6 +6,7 @@
 from bitorch_engine.layers.qlinear.nbit import MPQLinearBase
 from bitorch_engine.utils.safe_import import import_extension
 from bitorch_engine.utils.model_helper import flatten_x, unflatten_x
+from bitorch_engine.layers.qlinear.nbit.cuda.utils import unpack_qweight
 
 
 q_linear_cuda = import_extension("q_linear_cuda")
@@ -47,19 +48,33 @@ def forward(ctx, x: torch.Tensor, qweight: torch.Tensor, a_bit: int, w_bit: int,
         Returns:
             torch.Tensor: The result of the quantized linear operation.
         """
-        x, shape = flatten_x(x)
-        output = q_linear_cuda.mpq_forward(x, qweight, scales, zeros, g_idx, a_bit, w_bit, asym)
-        if is_training:
+        def setup_qweight():
             qweight.scales = scales
             qweight.zeros = zeros
             qweight.g_idx = g_idx
             qweight.w_bit = w_bit
-            qweight.privileged_grad = privileged_grad
             qweight.asym = asym
             qweight.layer_type = 1
+
+        x, original_shape = flatten_x(x)
+
+        if x.size(0) > 32: # use pytorch api
+            setup_qweight()
+            # Reconstruct the floating-point weight
+            fp_weight = unpack_qweight(qweight)
+            output = torch.matmul(x, fp_weight)
+        else:
+            output = q_linear_cuda.mpq_forward(x, qweight, scales, zeros, g_idx, a_bit, w_bit, asym)
+
+        if is_training:
+            qweight.privileged_grad = privileged_grad
+            if qweight.scales is None:
+                setup_qweight()
             ctx.a_bit = a_bit
             ctx.save_for_backward(x, qweight)
-        output = unflatten_x(output, shape)
+
+        output = unflatten_x(output, original_shape)
+
         return output
 
     @staticmethod
@@ -100,9 +115,8 @@ def backward(ctx: torch.autograd.function.BackwardCFunction,
                                                           output_gradient, a_bit, w_bit, asym)
         #==================================================================#
 
-        # (n, m) * (m, k) = (n, k)
         if qweight.requires_grad: # This additional check is required by peft training.
-            qweight.privileged_grad = output_gradient.t().mm(input).t() # (k, n)
+            qweight.privileged_grad = input.t().mm(output_gradient) # (k, m) * (m, n) = (k, n)
 
         grad_input = unflatten_x(grad_input, shape)
 

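The forward pass above now dispatches on batch size: batches with more than 32 rows unpack the quantized weight once and go through torch.matmul, while smaller batches keep the fused mpq_forward kernel. A simplified, self-contained sketch of that dispatch pattern (unpack_to_fp16 and fused_kernel are stand-ins for unpack_qweight and q_linear_cuda.mpq_forward; the threshold mirrors the diff):

import torch

BATCH_THRESHOLD = 32  # same cutoff as in the diff

def mpq_linear_forward(x, qweight, unpack_to_fp16, fused_kernel):
    """Dispatch between a fused low-bit kernel and unpack + matmul.

    unpack_to_fp16: stand-in for unpack_qweight (returns an fp16 (k, n) weight).
    fused_kernel:   stand-in for q_linear_cuda.mpq_forward.
    """
    if x.size(0) > BATCH_THRESHOLD:
        # Larger batch: reconstruct fp16 weights once and let matmul use cuBLAS.
        fp_weight = unpack_to_fp16(qweight)
        return torch.matmul(x, fp_weight)
    # Small batch: the fused dequantize-and-multiply kernel is used instead.
    return fused_kernel(x, qweight)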
bitorch_engine/layers/qlinear/nbit/cuda/utils.py

Lines changed: 20 additions & 7 deletions

@@ -69,7 +69,7 @@ def unpack_qweight(qweight: MPQWeightParameter) -> torch.Tensor:
     return weights
 
 
-def pack_fp_weight(weight: torch.Tensor, qweight: MPQWeightParameter) -> torch.Tensor:
+def pack_fp_weight(weight: torch.Tensor, qweight: MPQWeightParameter, unpacked_zeros: torch.Tensor = None) -> torch.Tensor:
     """Packs the fp16 weight into a quantized weight format using the attributes defined in the QweightParameter.
 
     This function handles three main scenarios:
@@ -100,8 +100,22 @@ def pack_fp_weight(weight: torch.Tensor, qweight: MPQWeightParameter) -> torch.T
 
     # Process based on layer_type and existence of q_perm for quantization
     if layer_type == 1 or (layer_type == 2 and qweight.q_group_map is None): # MPQLinear or MBWQLinear-q4
-        if asym:
-            intweight = torch.round(weight / scales[g_idx] + zeros[g_idx]).to(torch.int32).clamp(0, 2**w_bit-1)
+        if asym: # this if-branch is for classical GPTQ-style models
+            if unpacked_zeros is not None:
+                zeros = unpacked_zeros
+            elif zeros.dtype == torch.int32:
+                wf = torch.tensor(list(range(0, 32, w_bit)), dtype=torch.int32,
+                                  device=qweight.device).unsqueeze(0)
+                zeros_unpack = torch.bitwise_right_shift(
+                    torch.unsqueeze(zeros, 2).expand(-1, -1, 32 // w_bit),
+                    wf.unsqueeze(0)).to(torch.int16 if w_bit == 8 else torch.int8)
+                torch.bitwise_and(zeros_unpack, (2 ** w_bit) - 1, out=zeros_unpack)
+                zeros_unpack = zeros_unpack + 1
+                zeros = zeros_unpack.reshape(-1, qweight.size(-1))
+            else:
+                raise ValueError(f"Error: Got invalid dtype of qweight.zeros while packing fp weight.")
+
+            intweight = torch.round(weight / scales[g_idx.long()] + zeros[g_idx.long()]).to(torch.int32).clamp(0, 2**w_bit-1)
         else:
             if g_idx is None:
                 # Adjust scales and zeros for symmetric quantization without group index
@@ -114,8 +128,7 @@ def pack_fp_weight(weight: torch.Tensor, qweight: MPQWeightParameter) -> torch.T
                 intweight = torch.round((weight + zeros) / scales).to(torch.int32).clamp(0, 2 ** w_bit - 1)
             else:
                 # Calculate integer weights for symmetric quantization with group index
-                # TODO: recalculate scales and zeros?
-                intweight = torch.round((weight + zeros[g_idx]) / scales[g_idx]).to(torch.int32).clamp(0, 2**w_bit-1)
+                intweight = torch.round((weight + zeros[g_idx.long()]) / scales[g_idx.long()]).to(torch.int32).clamp(0, 2**w_bit-1)
 
     # Perform parallel bitpacking
     wf = torch.tensor(list(range(0, 32, w_bit)), dtype=torch.int32, device=qweight.device).unsqueeze(0)
@@ -128,8 +141,8 @@ def pack_fp_weight(weight: torch.Tensor, qweight: MPQWeightParameter) -> torch.T
             dtype=torch.int32
         )
     else:
-        # TODO: Placeholder for mixed-bit-width quantization method
-        raise NotImplementedError("Error: pack_fp_weight for MBWQLinear using mixed-bit-width not supported yet.")
+        # TODO: Placeholder for channel-mix quantization method
+        raise NotImplementedError("Error: pack_fp_weight for MBWQLinear using channel-mix quantization not supported yet.")
 
     return intweight.to(torch.int32)
 

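In the asymmetric (GPTQ-style) branch added above, packing uses intweight = round(w / s + z) clamped to the w_bit range, i.e. the inverse of dequantizing with w_hat = (q - z) * s per group. A small self-contained sketch of that quantize/dequantize round trip (group layout, scale, and zero-point choices here are illustrative, not the library's):

import torch

w_bit, group_size, in_feat, out_feat = 4, 32, 64, 16
weight = torch.randn(in_feat, out_feat)

# One scale/zero-point row per group of input channels (illustrative choices).
g_idx = torch.arange(in_feat) // group_size                                    # (in_feat,)
scales = weight.abs().reshape(-1, group_size, out_feat).amax(dim=1) / (2 ** (w_bit - 1) - 1)
zeros = torch.full_like(scales, float(2 ** (w_bit - 1)))                       # mid-range zero point

# Quantize, as in pack_fp_weight's asym branch: q = round(w / s + z), clamped to [0, 2^w_bit - 1].
intweight = torch.round(weight / scales[g_idx] + zeros[g_idx]).to(torch.int32).clamp(0, 2 ** w_bit - 1)

# Dequantize by inverting that formula: w_hat = (q - z) * s.
w_hat = (intweight.float() - zeros[g_idx]) * scales[g_idx]
print((weight - w_hat).abs().max())  # rounding error stays within ~scale/2 per group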
bitorch_engine/layers/qlinear/nbit/layer.py

Lines changed: 2 additions & 1 deletion

@@ -403,7 +403,8 @@ def init_gptq(self) -> None:
         """
         self.register_buffer('qzeros', torch.zeros((math.ceil(self.in_channels / self.group_size),
                                                     self.out_channels // 32 * self.w_bit), dtype=torch.int32))
-        self.scales = torch.ones((math.ceil(self.in_channels / self.group_size), self.out_channels), dtype=self.dtype)
+        self.register_buffer('scales', torch.ones((math.ceil(self.in_channels / self.group_size),
+                                                   self.out_channels), dtype=self.dtype))
         self.asym = True
 
     def init_gba(self) -> None:

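Registering scales as a buffer rather than assigning a plain tensor attribute means it is included in state_dict and moved together with the module by .to()/.half(). A minimal illustration of that difference, independent of this layer:

import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.plain = torch.ones(2, 3)                   # plain attribute: not in state_dict, not moved by .to()
        self.register_buffer("buf", torch.ones(2, 3))   # buffer: saved and moved with the module

m = Toy()
print("buf" in m.state_dict())    # True
print("plain" in m.state_dict())  # False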
bitorch_engine/utils/model_helper.py

Lines changed: 48 additions & 5 deletions

@@ -3,7 +3,7 @@
 
 import torch
 import torch.nn.functional as F
-from bitorch_engine.utils.quant_operators import nv_tensor_quant, gptq_stype_unpacking
+from bitorch_engine.utils.quant_operators import nv_tensor_quant, gptq_style_unpacking, gptq_style_zeros_packing
 from bitorch_engine.functions.cuda import tensor_to_packed_uint8, unpack_uint8_tensor
 
 
@@ -327,6 +327,39 @@ def init_weight(weight: torch.Tensor, cls: Type[torch.nn.Parameter]=torch.nn.Par
     return weight, scale_w
 
 
+def update_zeros(qweight, w, norm_grad, step_size, z_unpacked=None):
+    """
+    Updates the zeros attribute of the qweight object based on its layer type.
+
+    Args:
+        qweight: An object containing quantization parameters, including the zeros attribute.
+        w: Weight tensor.
+        norm_grad: Normalized gradient tensor.
+        step_size: Step size for updating zeros.
+        z_unpacked: Optional unpacked zeros tensor for specific layer types.
+    """
+    if qweight.layer_type == 2: # MBWQ-layer
+        q_perm = qweight.q_perm.unsqueeze(1).repeat(1, w.size(1)).long()
+        zeros_grad = torch.gather(norm_grad, dim=0, index=q_perm)
+        qweight.zeros.add_(
+            step_size * zeros_grad.view(-1, w.size(0) // qweight.scales.size(0), qweight.scales.size(-1)).mean(1)
+        )
+        del zeros_grad
+    elif qweight.layer_type == 1 and qweight.g_idx is not None: # MPQ-layer & GPTQ
+        zeros_unpack = z_unpacked[qweight.g_idx.long()]
+        zeros_unpack.add_(step_size * norm_grad)
+
+        g_idx = qweight.g_idx.long()
+        perm = torch.argsort(g_idx, dim=0)
+        zeros = zeros_unpack[perm, :].view(-1, w.size(0) // qweight.scales.size(0), qweight.scales.size(-1)).mean(1)
+
+        # pack to qzeros
+        qweight.zeros = gptq_style_zeros_packing(zeros, qweight.w_bit, zeros.size(-1), qweight.group_size)
+    else:
+        raise NotImplementedError(
+            "qweight.layer_type: '{}' has not been supported yet.".format(str(qweight.layer_type)))
+
+
 def qweight_update_fn(qweight: torch.nn.Parameter, exp_avg_s: torch.Tensor=None, exp_avg_l: torch.Tensor=None,
                       step: torch.Tensor=None, lr:float=1e-4, weight_decay:float=0.0, beta1:float=0.99,
                       beta2:float=0.9999, eps: float = 1e-6, dtype=torch.half, correct_bias=None, projector=None,
@@ -452,7 +485,9 @@ def qweight_update_fn(qweight: torch.nn.Parameter, exp_avg_s: torch.Tensor=None,
     elif isinstance(qweight, MPQWeightParameter):
 
         # unpack qweight
-        w = gptq_stype_unpacking(qweight).to(dtype)
+        w, z_unpacked = gptq_style_unpacking(qweight)
+        w = w.to(dtype)
+        z_unpacked = z_unpacked.to(dtype)
 
         # Decay the first and second moment running average coefficient
         # In-place operations to update the averages at the same time
@@ -475,11 +510,19 @@ def qweight_update_fn(qweight: torch.nn.Parameter, exp_avg_s: torch.Tensor=None,
 
         w.add_(norm_grad, alpha=-step_size)
 
-        if weight_decay > 0.0:
-            w.add_(w, alpha=(-lr * weight_decay))
+        # ===== update zeros ===== #
+        # We are not performing the gradient update for 'zeros' in the conventional way.
+        # Instead, we are making a special handling here because, although 'zeros' is of fp data type,
+        # in our optimization scenario, it is tied to the updates of 'qweight'.
+        # Moreover, 'zeros' is not always updated but interacts with 'qweight' at a relatively sparse frequency.
+        # If we were to update 'zeros' as a regular fp-parameter, it might not allow us the flexibility
+        # to design these interactions conveniently.
+        # Considering this is a beta version, future updates and adjustments might be possible.
+        if step % 5 == 0:
+            update_zeros(qweight, w, norm_grad, step_size, z_unpacked)
 
         # pack fp weight back to Q-weight and update qweight data
-        qweight.data = pack_fp_weight(w, qweight)
+        qweight.data = pack_fp_weight(w, qweight, z_unpacked)
 
         # manually empty cuda cache.
         del w

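Taken together, the MPQ branch of qweight_update_fn now runs an unpack, update, occasionally-adjust-zeros, repack loop. A simplified sketch of that control flow, with unpack_fn, pack_fn, and update_zeros_fn standing in for gptq_style_unpacking, pack_fp_weight, and update_zeros from the diff:

import torch

def mpq_step(qweight, norm_grad, step, step_size,
             unpack_fn, pack_fn, update_zeros_fn, zeros_every=5):
    """One simplified optimizer step for a packed MPQ weight.

    unpack_fn / pack_fn / update_zeros_fn are stand-ins for
    gptq_style_unpacking, pack_fp_weight and update_zeros in the diff.
    """
    # 1. Reconstruct fp16 weights and unpacked zero points from the packed parameter.
    w, z_unpacked = unpack_fn(qweight)

    # 2. Apply the gradient step in floating point.
    w.add_(norm_grad, alpha=-step_size)

    # 3. Only touch the zero points at a sparse frequency, as in the diff (every 5 steps).
    if step % zeros_every == 0:
        update_zeros_fn(qweight, w, norm_grad, step_size, z_unpacked)

    # 4. Re-quantize and pack the updated weights back into the int32 storage.
    qweight.data = pack_fn(w, qweight, z_unpacked)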
bitorch_engine/utils/quant_operators.py

Lines changed: 26 additions & 2 deletions

@@ -1,4 +1,5 @@
 from typing import Tuple
+import math
 
 import torch
 
@@ -306,7 +307,7 @@ def q4_quantization(input: torch.Tensor, scale_a: torch.Tensor=None, eps: torch.
     return (input / scale_a).round().clamp(Qn, Qp)
 
 
-def gptq_stype_unpacking(qweight) -> torch.Tensor:
+def gptq_style_unpacking(qweight) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Reconstructs the fp16 weight tensor from the input quantized weight parameter in GPTQ style.
 
@@ -341,4 +342,27 @@ def gptq_stype_unpacking(qweight) -> torch.Tensor:
     else:
         weights = weight * qweight.scales[qweight.g_idx.long()] - qweight.zeros[qweight.g_idx.long()]
 
-    return weights
+    return weights, zeros
+
+
+def gptq_style_zeros_packing(zeros: torch.Tensor, w_bit: int, out_features: int, group_size: int) -> torch.Tensor:
+    """
+    Packs the zeros tensor in GPTQ style for efficient storage and computation.
+
+    Args:
+        zeros (torch.Tensor): Input tensor containing zeros.
+        w_bit (int): Number of bits for weight quantization.
+        out_features (int): Number of output features.
+        group_size (int): Size of the group for packing.
+
+    Returns:
+        torch.Tensor: Packed tensor with reduced storage.
+    """
+
+    zeros = zeros.reshape(zeros.shape[0], math.ceil(out_features // 32 * w_bit), 32//w_bit).to(torch.int32)
+    zeros_pack = zeros - 1
+    wf = torch.arange(0, 32, w_bit, device=zeros.device, dtype=torch.int32)
+    zeros_pack = torch.bitwise_and(zeros_pack, (2 ** w_bit) - 1)
+    zeros_pack = torch.bitwise_left_shift(zeros_pack.to(torch.int32), wf.unsqueeze(0).unsqueeze(1))
+    zeros_pack = zeros_pack.sum(dim=-1).to(torch.int32)
+    return zeros_pack

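gptq_style_zeros_packing stores (z - 1) values w_bit bits at a time inside int32 words, which is the inverse of the unpacking branch added to pack_fp_weight (shift right, mask, add 1). A quick round-trip check of that bit layout, written independently of the library (sizes and the pack/unpack helpers here are illustrative):

import torch

w_bit = 4
vals_per_int32 = 32 // w_bit          # 8 nibbles per int32 word
groups, out_features = 2, 16          # illustrative sizes; out_features is a multiple of 32 // w_bit

# Unpacked zero points in the usual GPTQ convention (stored values are z - 1).
zeros = torch.randint(1, 2 ** w_bit + 1, (groups, out_features), dtype=torch.int32)

# Pack: subtract 1, mask to w_bit, shift each value into its slot, sum into int32 words.
z = (zeros - 1) & (2 ** w_bit - 1)
shifts = torch.arange(0, 32, w_bit, dtype=torch.int32)
packed = (z.reshape(groups, -1, vals_per_int32) << shifts).sum(-1).to(torch.int32)

# Unpack: shift right, mask, add 1 (mirrors the branch added to pack_fp_weight).
unpacked = (packed.unsqueeze(2) >> shifts) & (2 ** w_bit - 1)
unpacked = (unpacked + 1).reshape(groups, out_features)

assert torch.equal(unpacked, zeros)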
docker/build_scripts/install_modified_pytorch.sh

Lines changed: 1 addition & 1 deletion

@@ -24,7 +24,7 @@ fi
 if [ "${from_image}" == "pytorch/pytorch:2.2.0-cuda12.1-cudnn8-devel" ]; then
   gdrive_id="1LjFNImboq8QeFSompMS2gPjBRYtP2Dsz"
   file="torch-2.2.2-cp310-cp310-linux_x86_64.whl"
-  checksum="2a5953dab7be6c1640112e38ae7519ad88180d9fa79faab6c86dbee6b1cc210e"
+  checksum="bcc0ba7f121ee2f42ed0a59f01d4e3d70f82a8981be0be25c5e0fe0635a54b2d"
 fi
 #if [ "${from_image}" == "pytorch/pytorch:X.X.X-cudaXX.X-cudnn8-devel" ]; then
 #  gdrive_id="xxx"

version.txt

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-0.2.0
+0.2.4
