Skip to content

Commit 7a90148

Browse files
Fixed GPTQ algorithm to update weights in place (`ops.slice_update` instead of rebuilding tensors with `ops.concatenate`)
1 parent 24b7501 commit 7a90148

File tree

2 files changed

+9
-14
lines changed

2 files changed

+9
-14
lines changed

keras/src/models/model_test.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1391,8 +1391,3 @@ def test_quantize_gptq_with_data_gen(self):
13911391
nsamples=16, seqlen=128, vocab_size=1000
13921392
)
13931393
_run_gptq_test_on_dataset(self, generator_dataset)
1394-
1395-
@pytest.mark.slow
1396-
def test_quantize_gptq_with_wikitext2(self):
1397-
"""Tests GPTQ with the 'wikitext2' dataset identifier."""
1398-
_run_gptq_test_on_dataset(self, "wikitext2")

keras/src/quantizers/gptq.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -135,23 +135,23 @@ def quantize_and_correct_block(
135135
self.quantizer.maxq,
136136
)[:, 0]
137137

138-
Q1 = ops.concatenate(
139-
[Q1[:, :i], ops.expand_dims(q, 1), Q1[:, i + 1 :]], axis=1
140-
)
138+
Q1 = ops.slice_update(Q1, (0, i), ops.expand_dims(q, axis=1))
141139
err = (w - q) / d
142-
Err1 = ops.concatenate(
143-
[Err1[:, :i], ops.expand_dims(err, 1), Err1[:, i + 1 :]],
144-
axis=1,
140+
Err1 = ops.slice_update(
141+
Err1, (0, i), ops.expand_dims(err, axis=1)
145142
)
146143

147144
if i < count - 1:
148145
update = ops.matmul(
149146
ops.expand_dims(err, 1),
150147
ops.expand_dims(Hinv1[i, i + 1 :], 0),
151148
)
152-
W1 = ops.concatenate(
153-
[W1[:, : i + 1], W1[:, i + 1 :] - update], axis=1
154-
)
149+
150+
# Efficiently update the remaining part of the W1 tensor.
151+
# This is equivalent to W1[:, i + 1 :] -= update
152+
slice_to_update = W1[:, i + 1 :]
153+
updated_slice = slice_to_update - update
154+
W1 = ops.slice_update(W1, (0, i + 1), updated_slice)
155155

156156
Q = ops.concatenate([Q[:, :i1], Q1, Q[:, i2:]], axis=1)
157157

0 commit comments

Comments
 (0)