|
7 | 7 |
|
8 | 8 | from keras_hub.src.models.bert.bert_text_classifier import BertTextClassifier
|
9 | 9 | from keras_hub.src.models.causal_lm import CausalLM
|
| 10 | +from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone |
| 11 | +from keras_hub.src.models.gemma.gemma_causal_lm import GemmaCausalLM |
| 12 | +from keras_hub.src.models.gemma.gemma_causal_lm_preprocessor import ( |
| 13 | + GemmaCausalLMPreprocessor, |
| 14 | +) |
| 15 | +from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer |
10 | 16 | from keras_hub.src.models.gpt2.gpt2_causal_lm import GPT2CausalLM
|
11 | 17 | from keras_hub.src.models.image_classifier import ImageClassifier
|
12 | 18 | from keras_hub.src.models.preprocessor import Preprocessor
|
@@ -171,3 +177,85 @@ def test_save_to_preset_custom_backbone_and_preprocessor(self):
|
171 | 177 | restored_task = ImageClassifier.from_preset(save_dir)
|
172 | 178 | actual = restored_task.predict(batch)
|
173 | 179 | self.assertAllClose(expected, actual)
|
| 180 | + |
def _create_gemma_for_export_tests(self):
    """Build a Gemma causal LM and its preprocessor for the export tests.

    The tokenizer is loaded from the `gemma_export_vocab.spm` file in the
    test data directory, and the backbone's vocabulary size is taken from
    that tokenizer.

    Returns:
        A `(causal_lm, preprocessor)` tuple, where `preprocessor` is the
        same object that was passed to the causal LM.
    """
    vocab_path = os.path.join(
        self.get_test_data_dir(), "gemma_export_vocab.spm"
    )
    tokenizer = GemmaTokenizer(proto=vocab_path)
    backbone_config = {
        "vocabulary_size": tokenizer.vocabulary_size(),
        "num_layers": 2,
        "num_query_heads": 4,
        "num_key_value_heads": 1,
        "hidden_dim": 512,
        "intermediate_dim": 1024,
        "head_dim": 128,
    }
    backbone = GemmaBackbone(**backbone_config)
    preprocessor = GemmaCausalLMPreprocessor(tokenizer=tokenizer)
    model = GemmaCausalLM(backbone=backbone, preprocessor=preprocessor)
    return model, preprocessor
| 196 | + |
def test_export_attached(self):
    """Export a task with its preprocessor attached and check the outputs."""
    model, _ = self._create_gemma_for_export_tests()
    output_dir = os.path.join(self.get_temp_dir(), "export_attached")
    model.export_to_transformers(output_dir)
    # Smoke check only: both the model config and the tokenizer config
    # must be present in the export directory.
    for filename in ("config.json", "tokenizer_config.json"):
        self.assertTrue(os.path.exists(os.path.join(output_dir, filename)))
| 208 | + |
def test_export_attached_with_lm_head(self):
    """Repeat the attached-export check under an LM-head-specific name.

    NOTE(review): per the original author's comment, an attached export
    always includes the LM head, so this intentionally performs the same
    checks as `test_export_attached` and exists for explicit coverage.
    """
    model, _ = self._create_gemma_for_export_tests()
    output_dir = os.path.join(self.get_temp_dir(), "export_attached_lm_head")
    model.export_to_transformers(output_dir)
    # Smoke check only: both the model config and the tokenizer config
    # must be present in the export directory.
    for filename in ("config.json", "tokenizer_config.json"):
        self.assertTrue(os.path.exists(os.path.join(output_dir, filename)))
| 224 | + |
def test_export_detached(self):
    """Export the model and the preprocessor separately and check each.

    The task is exported with its preprocessor temporarily set to `None`,
    then the preprocessor is exported on its own to a second directory.
    """
    model, preprocessor = self._create_gemma_for_export_tests()
    # Each export gets its own directory from a separate `get_temp_dir()`
    # call.
    backbone_dir = os.path.join(
        self.get_temp_dir(), "export_detached_backbone"
    )
    preprocessor_dir = os.path.join(
        self.get_temp_dir(), "export_detached_preprocessor"
    )

    # Detach the preprocessor for the model export, then put it back.
    attached = model.preprocessor
    model.preprocessor = None
    model.export_to_transformers(backbone_dir)
    model.preprocessor = attached

    preprocessor.export_to_transformers(preprocessor_dir)

    # The model-only export has a config but no tokenizer config; the
    # preprocessor export provides the tokenizer config.
    model_config = os.path.join(backbone_dir, "config.json")
    stray_tokenizer_config = os.path.join(
        backbone_dir, "tokenizer_config.json"
    )
    tokenizer_config = os.path.join(
        preprocessor_dir, "tokenizer_config.json"
    )
    self.assertTrue(os.path.exists(model_config))
    self.assertFalse(os.path.exists(stray_tokenizer_config))
    self.assertTrue(os.path.exists(tokenizer_config))
| 253 | + |
def test_export_missing_tokenizer(self):
    """Export raises `ValueError` when the tokenizer has been removed."""
    model, preprocessor = self._create_gemma_for_export_tests()
    output_dir = os.path.join(self.get_temp_dir(), "export_missing_tokenizer")
    # The preprocessor stays attached to the model but has no tokenizer.
    preprocessor.tokenizer = None
    with self.assertRaises(ValueError):
        model.export_to_transformers(output_dir)
0 commit comments