Skip to content

Commit 4ac758e

Browse files
removed numerics like +,-,* etc and used keras.ops
1 parent 674a7bd commit 4ac758e

File tree

1 file changed

+20
-13
lines changed

1 file changed

+20
-13
lines changed

keras/src/quantizers/gptq.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -100,16 +100,21 @@ def update_hessian_with_batch(self, inp):
100100
"match input features ({inp.shape[-1]})."
101101
)
102102

103-
current_H = 2 * ops.matmul(ops.transpose(inp), inp)
103+
current_H = ops.multiply(2, ops.matmul(ops.transpose(inp), inp))
104104

105105
if self.nsamples == 0:
106106
self.H = current_H
107107
else:
108-
self.H = self.H * (self.nsamples / (self.nsamples + inp.shape[0]))
109-
self.H += current_H * (
110-
inp.shape[0] / (self.nsamples + inp.shape[0])
111-
)
112-
self.nsamples += inp.shape[0]
108+
total_samples = ops.add(self.nsamples, inp.shape[0])
109+
old_H_weight = ops.divide(self.nsamples, total_samples)
110+
current_H_weight = ops.divide(inp.shape[0], total_samples)
111+
112+
# Update the accumulated Hessian
113+
term1 = ops.multiply(self.H, old_H_weight)
114+
term2 = ops.multiply(current_H, current_H_weight)
115+
self.H = ops.add(term1, term2)
116+
117+
self.nsamples = ops.add(self.nsamples, inp.shape[0])
113118

114119
def quantize_and_correct_block(
115120
self, blocksize=128, percdamp=0.01, group_size=-1, actorder=False
@@ -172,12 +177,14 @@ def quantize_and_correct_block(
172177
diag_H = ops.diagonal(H)
173178
dead = ops.equal(diag_H, 0.0)
174179
diag_H = ops.where(dead, 1.0, diag_H)
175-
H = H + ops.diag(ops.where(dead, 1.0, ops.zeros_like(diag_H)))
180+
H = ops.add(H, ops.diag(ops.where(dead, 1.0, ops.zeros_like(diag_H))))
176181

177182
# Add dampening factor to the Hessian diagonal
178-
damp = percdamp * ops.mean(diag_H)
179-
diag_H = diag_H + damp
180-
H = (H - ops.diag(ops.diagonal(H))) + ops.diag(diag_H)
183+
damp = ops.multiply(percdamp, ops.mean(diag_H))
184+
diag_H = ops.add(diag_H, damp)
185+
H = ops.add(
186+
ops.subtract(H, ops.diag(ops.diagonal(H))), ops.diag(diag_H)
187+
)
181188

182189
# Compute the inverse Hessian, which is used for error correction
183190
Hinv = ops.linalg.inv(H)
@@ -217,7 +224,7 @@ def quantize_and_correct_block(
217224
)[:, 0]
218225

219226
Q1 = ops.slice_update(Q1, (0, i), ops.expand_dims(q, axis=1))
220-
err = (w - q) / d
227+
err = ops.divide(ops.subtract(w, q), d)
221228
Err1 = ops.slice_update(
222229
Err1, (0, i), ops.expand_dims(err, axis=1)
223230
)
@@ -230,7 +237,7 @@ def quantize_and_correct_block(
230237

231238
# Efficiently update the remaining part of the W1 tensor.
232239
slice_to_update = W1[:, i + 1 :]
233-
updated_slice = slice_to_update - update
240+
updated_slice = ops.subtract(slice_to_update, update)
234241
W1 = ops.slice_update(W1, (0, i + 1), updated_slice)
235242

236243
# Update the full quantized matrix Q with the processed block
@@ -239,7 +246,7 @@ def quantize_and_correct_block(
239246
if i2 < self.rows:
240247
update_total = ops.matmul(Err1, Hinv[i1:i2, i2:])
241248
W = ops.concatenate(
242-
[W[:, :i2], W[:, i2:] - update_total], axis=1
249+
[W[:, :i2], ops.subtract(W[:, i2:], update_total)], axis=1
243250
)
244251

245252
if actorder:

0 commit comments

Comments
 (0)