Skip to content

Commit efe96c9

Browse files
authored
Add generic export_to_transformers to the base classes (#2346)
* export_to_transformers * Address comments * Fixes AutoTokenizer and AutoModel compatibility * Add tests * Add lm_head * export_to_transformers * Address comments * Fixes AutoTokenizer and AutoModel compatibility * Add tests * Add lm_head * resolve gemma_test * Address comments and add llama model export support * Move llama to a new PR * address nits
1 parent d12bc61 commit efe96c9

File tree

15 files changed

+459
-76
lines changed

15 files changed

+459
-76
lines changed

keras_hub/src/models/backbone.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,3 +277,19 @@ def load_lora_weights(self, filepath):
277277
layer.lora_kernel_a.assign(lora_kernel_a)
278278
layer.lora_kernel_b.assign(lora_kernel_b)
279279
store.close()
280+
281+
def export_to_transformers(self, path):
282+
"""Export the backbone model to HuggingFace Transformers format.
283+
284+
This saves the backbone's configuration and weights in a format
285+
compatible with HuggingFace Transformers. For unsupported model
286+
architectures, a ValueError is raised.
287+
288+
Args:
289+
path: str. Path to save the exported model.
290+
"""
291+
from keras_hub.src.utils.transformers.export.hf_exporter import (
292+
export_backbone,
293+
)
294+
295+
export_backbone(self, path)

keras_hub/src/models/backbone_test.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from keras_hub.src.models.backbone import Backbone
77
from keras_hub.src.models.bert.bert_backbone import BertBackbone
8+
from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone
89
from keras_hub.src.models.gpt2.gpt2_backbone import GPT2Backbone
910
from keras_hub.src.tests.test_case import TestCase
1011
from keras_hub.src.utils.preset_utils import CONFIG_FILE
@@ -105,3 +106,40 @@ def test_save_to_preset(self):
105106
ref_out = backbone(data)
106107
new_out = restored_backbone(data)
107108
self.assertAllClose(ref_out, new_out)
109+
110+
def test_export_supported_model(self):
111+
backbone_config = {
112+
"vocabulary_size": 1000,
113+
"num_layers": 2,
114+
"num_query_heads": 4,
115+
"num_key_value_heads": 1,
116+
"hidden_dim": 512,
117+
"intermediate_dim": 1024,
118+
"head_dim": 128,
119+
}
120+
backbone = GemmaBackbone(**backbone_config)
121+
export_path = os.path.join(self.get_temp_dir(), "export_backbone")
122+
backbone.export_to_transformers(export_path)
123+
# Basic check: config file exists
124+
self.assertTrue(
125+
os.path.exists(os.path.join(export_path, "config.json"))
126+
)
127+
128+
def test_export_unsupported_model(self):
129+
backbone_config = {
130+
"vocabulary_size": 1000,
131+
"num_layers": 2,
132+
"num_query_heads": 4,
133+
"num_key_value_heads": 1,
134+
"hidden_dim": 512,
135+
"intermediate_dim": 1024,
136+
"head_dim": 128,
137+
}
138+
139+
class UnsupportedBackbone(GemmaBackbone):
140+
pass
141+
142+
backbone = UnsupportedBackbone(**backbone_config)
143+
export_path = os.path.join(self.get_temp_dir(), "unsupported")
144+
with self.assertRaises(ValueError):
145+
backbone.export_to_transformers(export_path)

keras_hub/src/models/causal_lm.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,3 +392,24 @@ def postprocess(x):
392392
outputs = [postprocess(x) for x in outputs]
393393

394394
return self._normalize_generate_outputs(outputs, input_is_scalar)
395+
396+
def export_to_transformers(self, path):
397+
"""Export the full CausalLM model to HuggingFace Transformers format.
398+
399+
This exports the trainable model, tokenizer, and configurations in a
400+
format compatible with HuggingFace Transformers. For unsupported model
401+
architectures, a ValueError is raised.
402+
403+
If the preprocessor is attached (default), both the trainable model and
404+
tokenizer are exported. To export only the trainable model, set
405+
`self.preprocessor = None` before calling this method, then export the
406+
preprocessor separately via `preprocessor.export_to_transformers(path)`.
407+
408+
Args:
409+
path: str. Path to save the exported model.
410+
"""
411+
from keras_hub.src.utils.transformers.export.hf_exporter import (
412+
export_to_safetensors,
413+
)
414+
415+
export_to_safetensors(self, path)

keras_hub/src/models/causal_lm_preprocessor.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,3 +180,17 @@ def sequence_length(self, value):
180180
self._sequence_length = value
181181
if self.packer is not None:
182182
self.packer.sequence_length = value
183+
184+
def export_to_transformers(self, path):
185+
"""Export the preprocessor to HuggingFace Transformers format.
186+
187+
Args:
188+
path: str. Path to save the exported preprocessor/tokenizer.
189+
"""
190+
if self.tokenizer is None:
191+
raise ValueError("Preprocessor must have a tokenizer for export.")
192+
from keras_hub.src.utils.transformers.export.hf_exporter import (
193+
export_tokenizer,
194+
)
195+
196+
export_tokenizer(self.tokenizer, path)

keras_hub/src/models/causal_lm_preprocessor_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
1+
import os
2+
13
import pytest
24

35
from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer
46
from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor
7+
from keras_hub.src.models.gemma.gemma_causal_lm_preprocessor import (
8+
GemmaCausalLMPreprocessor,
9+
)
10+
from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer
511
from keras_hub.src.models.gpt2.gpt2_causal_lm_preprocessor import (
612
GPT2CausalLMPreprocessor,
713
)
@@ -43,3 +49,22 @@ def test_from_preset_errors(self):
4349
with self.assertRaises(ValueError):
4450
# No loading on a non-keras model.
4551
GPT2CausalLMPreprocessor.from_preset("hf://spacy/en_core_web_sm")
52+
53+
def test_export_supported_preprocessor(self):
54+
proto = os.path.join(self.get_test_data_dir(), "gemma_export_vocab.spm")
55+
tokenizer = GemmaTokenizer(proto=proto)
56+
preprocessor = GemmaCausalLMPreprocessor(tokenizer=tokenizer)
57+
export_path = os.path.join(self.get_temp_dir(), "export_preprocessor")
58+
preprocessor.export_to_transformers(export_path)
59+
# Basic check: tokenizer config exists
60+
self.assertTrue(
61+
os.path.exists(os.path.join(export_path, "tokenizer_config.json"))
62+
)
63+
64+
def test_export_missing_tokenizer(self):
65+
preprocessor = GemmaCausalLMPreprocessor(tokenizer=None)
66+
export_path = os.path.join(
67+
self.get_temp_dir(), "export_missing_tokenizer"
68+
)
69+
with self.assertRaises(ValueError):
70+
preprocessor.export_to_transformers(export_path)

keras_hub/src/models/task_test.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@
77

88
from keras_hub.src.models.bert.bert_text_classifier import BertTextClassifier
99
from keras_hub.src.models.causal_lm import CausalLM
10+
from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone
11+
from keras_hub.src.models.gemma.gemma_causal_lm import GemmaCausalLM
12+
from keras_hub.src.models.gemma.gemma_causal_lm_preprocessor import (
13+
GemmaCausalLMPreprocessor,
14+
)
15+
from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer
1016
from keras_hub.src.models.gpt2.gpt2_causal_lm import GPT2CausalLM
1117
from keras_hub.src.models.image_classifier import ImageClassifier
1218
from keras_hub.src.models.preprocessor import Preprocessor
@@ -171,3 +177,85 @@ def test_save_to_preset_custom_backbone_and_preprocessor(self):
171177
restored_task = ImageClassifier.from_preset(save_dir)
172178
actual = restored_task.predict(batch)
173179
self.assertAllClose(expected, actual)
180+
181+
def _create_gemma_for_export_tests(self):
182+
proto = os.path.join(self.get_test_data_dir(), "gemma_export_vocab.spm")
183+
tokenizer = GemmaTokenizer(proto=proto)
184+
backbone = GemmaBackbone(
185+
vocabulary_size=tokenizer.vocabulary_size(),
186+
num_layers=2,
187+
num_query_heads=4,
188+
num_key_value_heads=1,
189+
hidden_dim=512,
190+
intermediate_dim=1024,
191+
head_dim=128,
192+
)
193+
preprocessor = GemmaCausalLMPreprocessor(tokenizer=tokenizer)
194+
causal_lm = GemmaCausalLM(backbone=backbone, preprocessor=preprocessor)
195+
return causal_lm, preprocessor
196+
197+
def test_export_attached(self):
198+
causal_lm, _ = self._create_gemma_for_export_tests()
199+
export_path = os.path.join(self.get_temp_dir(), "export_attached")
200+
causal_lm.export_to_transformers(export_path)
201+
# Basic check: config and tokenizer files exist
202+
self.assertTrue(
203+
os.path.exists(os.path.join(export_path, "config.json"))
204+
)
205+
self.assertTrue(
206+
os.path.exists(os.path.join(export_path, "tokenizer_config.json"))
207+
)
208+
209+
def test_export_attached_with_lm_head(self):
210+
# Since attached export always includes lm_head=True, this test verifies
211+
# the same but explicitly notes it for coverage.
212+
causal_lm, _ = self._create_gemma_for_export_tests()
213+
export_path = os.path.join(
214+
self.get_temp_dir(), "export_attached_lm_head"
215+
)
216+
causal_lm.export_to_transformers(export_path)
217+
# Basic check: config and tokenizer files exist
218+
self.assertTrue(
219+
os.path.exists(os.path.join(export_path, "config.json"))
220+
)
221+
self.assertTrue(
222+
os.path.exists(os.path.join(export_path, "tokenizer_config.json"))
223+
)
224+
225+
def test_export_detached(self):
226+
causal_lm, preprocessor = self._create_gemma_for_export_tests()
227+
export_path_backbone = os.path.join(
228+
self.get_temp_dir(), "export_detached_backbone"
229+
)
230+
export_path_preprocessor = os.path.join(
231+
self.get_temp_dir(), "export_detached_preprocessor"
232+
)
233+
original_preprocessor = causal_lm.preprocessor
234+
causal_lm.preprocessor = None
235+
causal_lm.export_to_transformers(export_path_backbone)
236+
causal_lm.preprocessor = original_preprocessor
237+
preprocessor.export_to_transformers(export_path_preprocessor)
238+
# Basic check: backbone has config, no tokenizer; preprocessor has
239+
# tokenizer config
240+
self.assertTrue(
241+
os.path.exists(os.path.join(export_path_backbone, "config.json"))
242+
)
243+
self.assertFalse(
244+
os.path.exists(
245+
os.path.join(export_path_backbone, "tokenizer_config.json")
246+
)
247+
)
248+
self.assertTrue(
249+
os.path.exists(
250+
os.path.join(export_path_preprocessor, "tokenizer_config.json")
251+
)
252+
)
253+
254+
def test_export_missing_tokenizer(self):
255+
causal_lm, preprocessor = self._create_gemma_for_export_tests()
256+
preprocessor.tokenizer = None
257+
export_path = os.path.join(
258+
self.get_temp_dir(), "export_missing_tokenizer"
259+
)
260+
with self.assertRaises(ValueError):
261+
causal_lm.export_to_transformers(export_path)
237 KB
Binary file not shown.
237 KB
Binary file not shown.

keras_hub/src/tokenizers/tokenizer.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,3 +261,18 @@ class like `keras_hub.models.Tokenizer.from_preset()`, or from
261261
if cls.backbone_cls != backbone_cls:
262262
cls = find_subclass(preset, cls, backbone_cls)
263263
return loader.load_tokenizer(cls, config_file, **kwargs)
264+
265+
def export_to_transformers(self, path):
266+
"""Export the tokenizer to HuggingFace Transformers format.
267+
268+
This saves tokenizer assets in a format compatible with HuggingFace
269+
Transformers.
270+
271+
Args:
272+
path: str. Path to save the exported tokenizer.
273+
"""
274+
from keras_hub.src.utils.transformers.export.hf_exporter import (
275+
export_tokenizer,
276+
)
277+
278+
export_tokenizer(self, path)

keras_hub/src/tokenizers/tokenizer_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from keras_hub.src.models.albert.albert_tokenizer import AlbertTokenizer
88
from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer
9+
from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer
910
from keras_hub.src.models.gpt2.gpt2_tokenizer import GPT2Tokenizer
1011
from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer
1112
from keras_hub.src.tests.test_case import TestCase
@@ -113,3 +114,24 @@ def test_save_to_preset(self, cls, preset_name, tokenizer_type):
113114
# Check config class.
114115
tokenizer_config = load_json(save_dir, TOKENIZER_CONFIG_FILE)
115116
self.assertEqual(cls, check_config_class(tokenizer_config))
117+
118+
def test_export_supported_tokenizer(self):
119+
proto = os.path.join(self.get_test_data_dir(), "gemma_export_vocab.spm")
120+
tokenizer = GemmaTokenizer(proto=proto)
121+
export_path = os.path.join(self.get_temp_dir(), "export_tokenizer")
122+
tokenizer.export_to_transformers(export_path)
123+
# Basic check: tokenizer config exists
124+
self.assertTrue(
125+
os.path.exists(os.path.join(export_path, "tokenizer_config.json"))
126+
)
127+
128+
def test_export_unsupported_tokenizer(self):
129+
proto = os.path.join(self.get_test_data_dir(), "gemma_export_vocab.spm")
130+
131+
class UnsupportedTokenizer(GemmaTokenizer):
132+
pass
133+
134+
tokenizer = UnsupportedTokenizer(proto=proto)
135+
export_path = os.path.join(self.get_temp_dir(), "unsupported_tokenizer")
136+
with self.assertRaises(ValueError):
137+
tokenizer.export_to_transformers(export_path)

0 commit comments

Comments
 (0)