@@ -100,16 +100,21 @@ def update_hessian_with_batch(self, inp):
100
100
"match input features ({inp.shape[-1]})."
101
101
)
102
102
103
- current_H = 2 * ops .matmul (ops .transpose (inp ), inp )
103
+ current_H = ops . multiply ( 2 , ops .matmul (ops .transpose (inp ), inp ) )
104
104
105
105
if self .nsamples == 0 :
106
106
self .H = current_H
107
107
else :
108
- self .H = self .H * (self .nsamples / (self .nsamples + inp .shape [0 ]))
109
- self .H += current_H * (
110
- inp .shape [0 ] / (self .nsamples + inp .shape [0 ])
111
- )
112
- self .nsamples += inp .shape [0 ]
108
+ total_samples = ops .add (self .nsamples , inp .shape [0 ])
109
+ old_H_weight = ops .divide (self .nsamples , total_samples )
110
+ current_H_weight = ops .divide (inp .shape [0 ], total_samples )
111
+
112
+ # Update the accumulated Hessian
113
+ term1 = ops .multiply (self .H , old_H_weight )
114
+ term2 = ops .multiply (current_H , current_H_weight )
115
+ self .H = ops .add (term1 , term2 )
116
+
117
+ self .nsamples = ops .add (self .nsamples , inp .shape [0 ])
113
118
114
119
def quantize_and_correct_block (
115
120
self , blocksize = 128 , percdamp = 0.01 , group_size = - 1 , actorder = False
@@ -172,12 +177,14 @@ def quantize_and_correct_block(
172
177
diag_H = ops .diagonal (H )
173
178
dead = ops .equal (diag_H , 0.0 )
174
179
diag_H = ops .where (dead , 1.0 , diag_H )
175
- H = H + ops .diag (ops .where (dead , 1.0 , ops .zeros_like (diag_H )))
180
+ H = ops . add ( H , ops .diag (ops .where (dead , 1.0 , ops .zeros_like (diag_H ) )))
176
181
177
182
# Add dampening factor to the Hessian diagonal
178
- damp = percdamp * ops .mean (diag_H )
179
- diag_H = diag_H + damp
180
- H = (H - ops .diag (ops .diagonal (H ))) + ops .diag (diag_H )
183
+ damp = ops .multiply (percdamp , ops .mean (diag_H ))
184
+ diag_H = ops .add (diag_H , damp )
185
+ H = ops .add (
186
+ ops .subtract (H , ops .diag (ops .diagonal (H ))), ops .diag (diag_H )
187
+ )
181
188
182
189
# Compute the inverse Hessian, which is used for error correction
183
190
Hinv = ops .linalg .inv (H )
@@ -217,7 +224,7 @@ def quantize_and_correct_block(
217
224
)[:, 0 ]
218
225
219
226
Q1 = ops .slice_update (Q1 , (0 , i ), ops .expand_dims (q , axis = 1 ))
220
- err = ( w - q ) / d
227
+ err = ops . divide ( ops . subtract ( w , q ), d )
221
228
Err1 = ops .slice_update (
222
229
Err1 , (0 , i ), ops .expand_dims (err , axis = 1 )
223
230
)
@@ -230,7 +237,7 @@ def quantize_and_correct_block(
230
237
231
238
# Efficiently update the remaining part of the W1 tensor.
232
239
slice_to_update = W1 [:, i + 1 :]
233
- updated_slice = slice_to_update - update
240
+ updated_slice = ops . subtract ( slice_to_update , update )
234
241
W1 = ops .slice_update (W1 , (0 , i + 1 ), updated_slice )
235
242
236
243
# Update the full quantized matrix Q with the processed block
@@ -239,7 +246,7 @@ def quantize_and_correct_block(
239
246
if i2 < self .rows :
240
247
update_total = ops .matmul (Err1 , Hinv [i1 :i2 , i2 :])
241
248
W = ops .concatenate (
242
- [W [:, :i2 ], W [:, i2 :] - update_total ], axis = 1
249
+ [W [:, :i2 ], ops . subtract ( W [:, i2 :], update_total ) ], axis = 1
243
250
)
244
251
245
252
if actorder :
0 commit comments