Skip to content

Commit 5d0c852

Browse files
Adds get_quantization_layer_structure hooks for GPTQ (#2462)
* Adds get_quantization_layer_structure hooks * cleanup * format
1 parent 9b25eee commit 5d0c852

File tree

8 files changed

+137
-0
lines changed

8 files changed

+137
-0
lines changed

keras_hub/src/models/causal_lm.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,3 +429,25 @@ def _post_quantize(self, mode, **kwargs):
429429
super()._post_quantize(mode, **kwargs)
430430
# Reset the compiled generate function.
431431
self.generate_function = None
432+
433+
def get_quantization_layer_structure(self, mode):
    """Describe the layer layout consumed by GPTQ quantization.

    Args:
        mode: The quantization mode. Only `"gptq"` is handled here.

    Returns:
        A dict with `"pre_block_layers"` (layers applied before the
        transformer stack — here just the embedding) and
        `"sequential_blocks"` (the transformer layers), or `None` when
        the mode is unsupported or the backbone does not expose the
        expected structure.
    """
    if mode != "gptq":
        return None

    backbone = self.backbone
    # Only backbones with a standard transformer stack are supported.
    if not hasattr(backbone, "transformer_layers"):
        return None

    # Prefer `token_embedding`; fall back to a plain `embedding` attr.
    embedding = None
    for attr_name in ("token_embedding", "embedding"):
        embedding = getattr(backbone, attr_name, None)
        if embedding is not None:
            break
    if embedding is None:
        return None

    return {
        "pre_block_layers": [embedding],
        "sequential_blocks": backbone.transformer_layers,
    }

keras_hub/src/models/gemma/gemma_causal_lm.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,3 +431,19 @@ def default_layer_intercept_fn(x, unused_i):
431431
)
432432
per_token_loss = per_token_loss_fn(target_ids, logits)
433433
return per_token_loss
434+
435+
def get_quantization_layer_structure(self, mode):
    """Describe the layer layout consumed by GPTQ quantization.

    Gemma scales token embeddings by `sqrt(hidden_dim)` before the
    transformer stack, so the pre-block stage is wrapped in a small
    functional model that applies both the embedding lookup and the
    scaling.

    Args:
        mode: The quantization mode. Only `"gptq"` is handled here.

    Returns:
        A dict with `"pre_block_layers"` and `"sequential_blocks"`, or
        `None` for unsupported modes.
    """
    if mode != "gptq":
        return None

    backbone = self.backbone
    # Mirror the backbone's embedding stage: lookup then scaling.
    token_ids = keras.Input(shape=(None,), dtype="int32")
    embedded = backbone.token_embedding(token_ids)
    scale = ops.cast(ops.sqrt(backbone.hidden_dim), embedded.dtype)
    embedded = embedded * scale
    pre_processor = keras.Model(inputs=token_ids, outputs=embedded)

    return {
        "pre_block_layers": [pre_processor],
        "sequential_blocks": backbone.transformer_layers,
    }

keras_hub/src/models/gemma/gemma_causal_lm_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,3 +295,17 @@ def layer_intercept_fn_for_testing(x, i):
295295
# Assert shapes for info exfiltrated into the parent context.
296296
self.assertEqual(ops.shape(embedded_prompts), expected_embedded_shape)
297297
self.assertEqual(ops.shape(scores), expected_score_shape)
298+
299+
def test_get_quantization_layer_structure(self):
    model = GemmaCausalLM(**self.init_kwargs)
    # GPTQ mode exposes the embedding wrapper plus the transformer
    # stack.
    structure = model.get_quantization_layer_structure("gptq")
    self.assertIsInstance(structure, dict)
    self.assertIn("pre_block_layers", structure)
    self.assertIn("sequential_blocks", structure)
    # The embedding + scaling stage is a single functional model.
    self.assertLen(structure["pre_block_layers"], 1)
    self.assertIsInstance(structure["pre_block_layers"][0], keras.Model)
    self.assertEqual(
        structure["sequential_blocks"], self.backbone.transformer_layers
    )

    # Other quantization modes are not handled by this hook.
    self.assertIsNone(model.get_quantization_layer_structure("int8"))

keras_hub/src/models/gpt2/gpt2_causal_lm.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,3 +420,20 @@ def default_layer_intercept_fn(x, unused_i):
420420
)
421421
per_token_loss = per_token_loss_fn(target_ids, logits)
422422
return per_token_loss
423+
424+
def get_quantization_layer_structure(self, mode):
    """Describe the layer layout consumed by GPTQ quantization.

    GPT-2 combines token and learned position embeddings (followed by
    dropout) before the transformer stack, so the pre-block stage is
    captured as a functional model replaying that pipeline.

    Args:
        mode: The quantization mode. Only `"gptq"` is handled here.

    Returns:
        A dict with `"pre_block_layers"` and `"sequential_blocks"`, or
        `None` for unsupported modes.
    """
    if mode != "gptq":
        return None

    backbone = self.backbone
    # Replay the backbone's embedding pipeline in a standalone model.
    inputs = keras.Input(shape=(None,), dtype="int32")
    token_embeddings = backbone.token_embedding(inputs)
    position_embeddings = backbone.position_embedding(token_embeddings)
    summed = backbone.embeddings_add(
        (token_embeddings, position_embeddings)
    )
    outputs = backbone.embeddings_dropout(summed)
    pre_processor = keras.Model(inputs=inputs, outputs=outputs)

    return {
        "pre_block_layers": [pre_processor],
        "sequential_blocks": backbone.transformer_layers,
    }

keras_hub/src/models/gpt2/gpt2_causal_lm_test.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from unittest.mock import patch
22

3+
import keras
34
import pytest
45
from keras import ops
56

@@ -199,3 +200,17 @@ def layer_intercept_fn_for_testing(x, i):
199200
# Assert shapes for info exfiltrated into the parent context.
200201
self.assertEqual(ops.shape(embedded_prompts), expected_embedded_shape)
201202
self.assertEqual(ops.shape(scores), expected_score_shape)
203+
204+
def test_get_quantization_layer_structure(self):
    model = GPT2CausalLM(**self.init_kwargs)
    # GPTQ mode exposes the embedding pipeline wrapper plus the
    # transformer stack.
    structure = model.get_quantization_layer_structure("gptq")
    self.assertIsInstance(structure, dict)
    self.assertIn("pre_block_layers", structure)
    self.assertIn("sequential_blocks", structure)
    # The token + position embedding pipeline is wrapped in a single
    # functional model.
    self.assertLen(structure["pre_block_layers"], 1)
    self.assertIsInstance(structure["pre_block_layers"][0], keras.Model)
    self.assertEqual(
        structure["sequential_blocks"], self.backbone.transformer_layers
    )

    # Other quantization modes are not handled by this hook.
    self.assertIsNone(model.get_quantization_layer_structure("int8"))

keras_hub/src/models/masked_lm.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,25 @@ def compile(
8484
weighted_metrics=weighted_metrics,
8585
**kwargs,
8686
)
87+
88+
def get_quantization_layer_structure(self, mode):
    """Describe the layer layout consumed by GPTQ quantization.

    Args:
        mode: The quantization mode. Only `"gptq"` is handled here.

    Returns:
        A dict mapping `"pre_block_layers"` to the embedding layer and
        `"sequential_blocks"` to the backbone's transformer layers, or
        `None` when the mode is unsupported or the backbone lacks the
        expected structure.
    """
    if mode != "gptq":
        return None

    backbone = self.backbone
    # A standard transformer stack is required.
    if not hasattr(backbone, "transformer_layers"):
        return None

    # Resolve the embedding: `token_embedding` first, then `embedding`.
    embedding = next(
        (
            layer
            for name in ("token_embedding", "embedding")
            if (layer := getattr(backbone, name, None)) is not None
        ),
        None,
    )
    if embedding is None:
        return None

    return {
        "pre_block_layers": [embedding],
        "sequential_blocks": backbone.transformer_layers,
    }

keras_hub/src/models/mistral/mistral_causal_lm_test.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,3 +199,18 @@ def layer_intercept_fn_for_testing(x, i):
199199
# Assert shapes for info exfiltrated into the parent context.
200200
self.assertEqual(ops.shape(embedded_prompts), expected_embedded_shape)
201201
self.assertEqual(ops.shape(scores), expected_score_shape)
202+
203+
def test_get_quantization_layer_structure(self):
    model = MistralCausalLM(**self.init_kwargs)
    structure = model.get_quantization_layer_structure("gptq")
    self.assertIsInstance(structure, dict)
    self.assertIn("pre_block_layers", structure)
    self.assertIn("sequential_blocks", structure)
    # Mistral uses the base-class structure: the bare token embedding
    # feeds straight into the transformer stack.
    self.assertEqual(
        structure["pre_block_layers"], [self.backbone.token_embedding]
    )
    self.assertEqual(
        structure["sequential_blocks"], self.backbone.transformer_layers
    )

    # Other quantization modes are not handled by this hook.
    self.assertIsNone(model.get_quantization_layer_structure("int8"))

keras_hub/src/models/phi3/phi3_causal_lm_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,19 @@ def test_all_presets(self):
115115
preset=preset,
116116
input_data=self.input_data,
117117
)
118+
119+
def test_get_quantization_layer_structure(self):
    model = Phi3CausalLM(**self.init_kwargs)
    structure = model.get_quantization_layer_structure("gptq")
    self.assertIsInstance(structure, dict)
    self.assertIn("pre_block_layers", structure)
    self.assertIn("sequential_blocks", structure)
    # Phi-3 uses the base-class structure: the bare token embedding
    # feeds straight into the transformer stack.
    self.assertEqual(
        structure["pre_block_layers"],
        [self.backbone.token_embedding],
    )
    self.assertEqual(
        structure["sequential_blocks"], self.backbone.transformer_layers
    )

    # Other quantization modes are not handled by this hook.
    self.assertIsNone(model.get_quantization_layer_structure("int8"))

0 commit comments

Comments (0)