Skip to content

Commit cdc7e29

Browse files
Adds store serialization tests for GPTQ (#21689)
* Add store serialization tests for GPTQ * remove TODO comments * kernels are now in packed form
1 parent 52893b0 commit cdc7e29

File tree

2 files changed

+44
-2
lines changed

2 files changed

+44
-2
lines changed

keras/src/layers/core/dense_test.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -831,7 +831,6 @@ def test_int4_kernel_returns_unpacked_form(self):
831831
def test_legacy_load_own_variables(self):
832832
# In previous versions, `load_own_variables` accepted a store with
833833
# numeric keys.
834-
# TODO(JyotinderSingh): add gptq_store test.
835834
float32_store = {
836835
"0": np.random.random((8, 16)).astype("float32"),
837836
"1": np.random.random((16,)).astype("float32"),
@@ -862,6 +861,18 @@ def test_legacy_load_own_variables(self):
862861
# outputs_grad_amax_history.
863862
"7": np.random.random((1024,)).astype("float32"),
864863
}
864+
gptq_store = {
865+
# bias
866+
"0": np.random.random((16,)).astype("float32"),
867+
# quantized_kernel
868+
"1": np.random.randint(0, 16, size=(8, 8), dtype="uint8"),
869+
# kernel_scale.
870+
"2": np.random.random((16, 1)).astype("float32"),
871+
# kernel_zero
872+
"3": np.random.random((16, 1)).astype("uint8"),
873+
# g_idx
874+
"4": np.random.random((8,)).astype("float32"),
875+
}
865876

866877
# Test float32 layer.
867878
layer = layers.Dense(units=16)
@@ -899,6 +910,16 @@ def test_legacy_load_own_variables(self):
899910
self.assertAllClose(layer.outputs_grad_scale, float8_store["6"])
900911
self.assertAllClose(layer.outputs_grad_amax_history, float8_store["7"])
901912

913+
# Test gptq-quantized layer.
914+
layer = layers.Dense(units=16, dtype="gptq/4/8_from_float32")
915+
layer.build((None, 8))
916+
layer.load_own_variables(gptq_store)
917+
self.assertAllClose(layer.bias, gptq_store["0"])
918+
self.assertAllClose(layer.quantized_kernel, gptq_store["1"])
919+
self.assertAllClose(layer.kernel_scale, gptq_store["2"])
920+
self.assertAllClose(layer.kernel_zero, gptq_store["3"])
921+
self.assertAllClose(layer.g_idx, gptq_store["4"])
922+
902923
def test_int4_gptq_kernel_returns_unpacked_form(self):
903924
"""Test that the `kernel` property returns the unpacked int4 GPTQ
904925
kernel."""

keras/src/layers/core/einsum_dense_test.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1043,7 +1043,6 @@ def test_int4_kernel_returns_unpacked_form(self):
10431043
def test_legacy_load_own_variables(self):
10441044
# In previous versions, `load_own_variables` accepted a store with
10451045
# numeric keys.
1046-
# TODO(JyotinderSingh): add gptq_store test.
10471046
float32_store = {
10481047
"0": np.random.random((3, 8, 32)).astype("float32"),
10491048
"1": np.random.random((32,)).astype("float32"),
@@ -1074,6 +1073,18 @@ def test_legacy_load_own_variables(self):
10741073
# outputs_grad_amax_history.
10751074
"7": np.random.random((1024,)).astype("float32"),
10761075
}
1076+
gptq_store = {
1077+
# bias
1078+
"0": np.random.random((32,)).astype("float32"),
1079+
# quantized_kernel
1080+
"1": np.random.randint(0, 16, size=(16, 24), dtype="uint8"),
1081+
# kernel_scale.
1082+
"2": np.random.random((32, 3)).astype("float32"),
1083+
# kernel_zero
1084+
"3": np.random.random((32, 3)).astype("uint8"),
1085+
# g_idx
1086+
"4": np.random.random((24,)).astype("float32"),
1087+
}
10771088
config = dict(
10781089
equation="ab,bcd->acd",
10791090
output_shape=(8, 32),
@@ -1116,6 +1127,16 @@ def test_legacy_load_own_variables(self):
11161127
self.assertAllClose(layer.outputs_grad_scale, float8_store["6"])
11171128
self.assertAllClose(layer.outputs_grad_amax_history, float8_store["7"])
11181129

1130+
# Test gptq-quantized layer.
1131+
layer = layers.EinsumDense(**config, dtype="gptq/4/8_from_float32")
1132+
layer.build((None, 3))
1133+
layer.load_own_variables(gptq_store)
1134+
self.assertAllClose(layer.bias, gptq_store["0"])
1135+
self.assertAllClose(layer.quantized_kernel, gptq_store["1"])
1136+
self.assertAllClose(layer.kernel_scale, gptq_store["2"])
1137+
self.assertAllClose(layer.kernel_zero, gptq_store["3"])
1138+
self.assertAllClose(layer.g_idx, gptq_store["4"])
1139+
11191140
def test_int4_gptq_kernel_returns_unpacked_form(self):
11201141
"""Test that the `kernel` property returns the unpacked int4 GPTQ
11211142
kernel."""

0 commit comments

Comments
 (0)