@@ -61,6 +61,27 @@ def __init__(self, layer):
61
61
self .H = ops .zeros ((self .rows , self .rows ), dtype = "float32" )
62
62
63
63
def update_hessian_with_batch (self , inp ):
64
+ """
65
+ Updates the running average of the Hessian matrix with a new batch.
66
+
67
+ This method computes the Hessian matrix for a given batch of input
68
+ activations and updates the accumulated Hessian (`self.H`) using a
69
+ numerically stable running average. This allows the Hessian to be
70
+ computed over a large dataset without loading all samples into memory
71
+ at once.
72
+
73
+ The input tensor is first reshaped into a 2D matrix [num_samples,
74
+ num_features] before the Hessian is calculated.
75
+
76
+ Args:
77
+ inp: A 2D or higher-dimensional tensor of input activations from a
78
+ calibration batch.
79
+
80
+ Raises:
81
+ ValueError: If the feature dimension of the input tensor `inp` does
82
+ not match the dimensions of the pre-initialized Hessian matrix
83
+ `self.H`.
84
+ """
64
85
if len (inp .shape ) > 2 :
65
86
inp = ops .reshape (inp , (- 1 , inp .shape [- 1 ]))
66
87
inp = ops .cast (inp , "float32" )
@@ -85,6 +106,51 @@ def update_hessian_with_batch(self, inp):
85
106
def quantize_and_correct_block (
86
107
self , blocksize = 128 , percdamp = 0.01 , groupsize = - 1 , actorder = False
87
108
):
109
+ """
110
+ Performs GPTQ quantization and correction on the layer's weights.
111
+
112
+ This method implements the core logic of the "Optimal Brain Quantization"
113
+ (OBQ) method, as applied by GPTQ, to quantize the weights of a single
114
+ layer. It iteratively quantizes blocks of weights and corrects for the
115
+ quantization error by updating the remaining weights.
116
+
117
+ The algorithm follows these main steps:
118
+ 1. **Initialization**: It optionally reorders the weight columns based
119
+ on activation magnitudes (`actorder=True`) to protect more salient
120
+ weights.
121
+ 2. **Hessian Modification**: The Hessian matrix `H`, pre-computed from
122
+ calibration data, is dampened to ensure its invertibility and
123
+ stability.
124
+ 3. **Iterative Quantization**: The function iterates through the
125
+ weight columns in blocks (`blocksize`). In each iteration, it:
126
+ a. Quantizes one column (`w`).
127
+ b. Calculates the quantization error (`err`).
128
+ c. Updates the remaining weights in the *current* block by
129
+ distributing the error, using the inverse Hessian (`Hinv`).
130
+ 4. **Block-wise Correction**: After a block is quantized, the total
131
+ error from that block is propagated to the *next* block of weights
132
+ to be processed.
133
+ 5. **Finalization**: The quantized weights (`Q`) are reordered back if
134
+ `actorder` was used, and the layer's weights are updated.
135
+
136
+ This implementation is based on the official GPTQ paper and repository.
137
+ For more details, see:
138
+ - Paper: https://arxiv.org/abs/2210.17323
139
+ - Original Code: https://github.com/IST-DASLab/gptq
140
+
141
+ Args:
142
+ blocksize (int, optional): The size of the weight block to process
143
+ at a time. Defaults to 128.
144
+ percdamp (float, optional): The percentage of dampening to add to the
145
+ Hessian's diagonal. A value of 0.01 is recommended.
146
+ Defaults to 0.01.
147
+ groupsize (int, optional): The number of weights that share the
148
+ same quantization parameters (scale and zero-point).
149
+ A value of -1 indicates per-channel quantization.
150
+ actorder (bool, optional): If True, reorders weight columns based
151
+ on their activation's second-order information.
152
+ """
153
+
88
154
W = ops .transpose (ops .cast (self .layer .kernel , "float32" ))
89
155
H = ops .cast (self .H , "float32" )
90
156
@@ -94,26 +160,32 @@ def quantize_and_correct_block(
94
160
H = ops .take (ops .take (H , perm , axis = 0 ), perm , axis = 1 )
95
161
invperm = ops .argsort (perm )
96
162
163
+ # Dampen the Hessian for stability
97
164
diag_H = ops .diagonal (H )
98
165
dead = ops .equal (diag_H , 0.0 )
99
166
diag_H = ops .where (dead , 1.0 , diag_H )
100
167
H = H + ops .diag (ops .where (dead , 1.0 , ops .zeros_like (diag_H )))
168
+
169
+ # Add dampening factor to the Hessian diagonal
101
170
damp = percdamp * ops .mean (diag_H )
102
171
diag_H = diag_H + damp
103
172
H = (H - ops .diag (ops .diagonal (H ))) + ops .diag (diag_H )
104
173
174
+ # Compute the inverse Hessian, which is used for error correction
105
175
Hinv = ops .linalg .inv (H )
106
176
Q = ops .zeros_like (W )
107
177
108
178
for i1 in range (0 , self .rows , blocksize ):
109
179
i2 = min (i1 + blocksize , self .rows )
110
180
count = i2 - i1
111
-
181
+ # Extract the current block of weights and its corresponding
182
+ # Hessian
112
183
W1 = W [:, i1 :i2 ]
113
184
Q1 = ops .zeros_like (W1 )
114
185
Err1 = ops .zeros_like (W1 )
115
186
Hinv1 = Hinv [i1 :i2 , i1 :i2 ]
116
187
188
+ # Process one column at a time within the block
117
189
for i in range (count ):
118
190
w = W1 [:, i ]
119
191
d = Hinv1 [i , i ]
@@ -128,6 +200,7 @@ def quantize_and_correct_block(
128
200
ops .expand_dims (w , 1 ), weight = True
129
201
)
130
202
203
+ # Quantize the current weight column
131
204
q = quantize (
132
205
ops .expand_dims (w , 1 ),
133
206
self .quantizer .scale ,
@@ -148,11 +221,11 @@ def quantize_and_correct_block(
148
221
)
149
222
150
223
# Efficiently update the remaining part of the W1 tensor.
151
- # This is equivalent to W1[:, i + 1 :] -= update
152
224
slice_to_update = W1 [:, i + 1 :]
153
225
updated_slice = slice_to_update - update
154
226
W1 = ops .slice_update (W1 , (0 , i + 1 ), updated_slice )
155
227
228
+ # Update the full quantized matrix Q with the processed block
156
229
Q = ops .concatenate ([Q [:, :i1 ], Q1 , Q [:, i2 :]], axis = 1 )
157
230
158
231
if i2 < self .rows :
@@ -169,6 +242,7 @@ def quantize_and_correct_block(
169
242
if isinstance (self .original_layer , EinsumDense ):
170
243
Q = ops .reshape (Q , self .kernel_shape )
171
244
245
+ # Set the new quantized weights in the original layer
172
246
new_weights = [ops .convert_to_numpy (Q )]
173
247
if self .original_layer .bias is not None :
174
248
new_weights .append (ops .convert_to_numpy (self .original_layer .bias ))
0 commit comments