Skip to content

Commit c459519

Browse files
authored
Convert a safetensor checkpoint from Hugging Face hub (#1662)
* chore: adding gemma and llama3 * chore: adding init * chore: removing hard coded values * chore: using backbone properties * chore: reformat * chore: review changes * chore: removing einops with custom np operations * fix: variable name * check: none type for reshape and transpose patterns * chore: fixing the nesting of reshape and transpose patterns * fixing nesting of patterns * chore: gemma weight rearrange fix * chore: adding a hook function to reshape and transpose the hf tensors to match the keras weights * fix: variable to assign * fix: gemma port * chore: adding tests * review comments * adding safetensors as a dep * chore: adding jax memory cleanup * utf 8 encoding * chore: changing tests * chore: fixing tests * fix tests * chore: adding guard rails for None types * Trigger Build * review suggestions * fix raising ValueError * fix error message
1 parent b58b56e commit c459519

18 files changed

+600
-30
lines changed

keras_nlp/src/models/backbone.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,16 @@
2020
from keras_nlp.src.utils.preset_utils import CONFIG_FILE
2121
from keras_nlp.src.utils.preset_utils import MODEL_WEIGHTS_FILE
2222
from keras_nlp.src.utils.preset_utils import check_config_class
23+
from keras_nlp.src.utils.preset_utils import check_format
2324
from keras_nlp.src.utils.preset_utils import get_file
2425
from keras_nlp.src.utils.preset_utils import jax_memory_cleanup
2526
from keras_nlp.src.utils.preset_utils import list_presets
2627
from keras_nlp.src.utils.preset_utils import list_subclasses
2728
from keras_nlp.src.utils.preset_utils import load_serialized_object
2829
from keras_nlp.src.utils.preset_utils import save_metadata
2930
from keras_nlp.src.utils.preset_utils import save_serialized_object
30-
from keras_nlp.src.utils.preset_utils import validate_metadata
3131
from keras_nlp.src.utils.python_utils import classproperty
32+
from keras_nlp.src.utils.transformers.convert import load_transformers_backbone
3233

3334

3435
@keras_nlp_export("keras_nlp.models.Backbone")
@@ -173,7 +174,11 @@ class like `keras_nlp.models.Backbone.from_preset()`, or from
173174
)
174175
```
175176
"""
176-
validate_metadata(preset)
177+
format = check_format(preset)
178+
179+
if format == "transformers":
180+
return load_transformers_backbone(cls, preset, load_weights)
181+
177182
preset_cls = check_config_class(preset)
178183
if not issubclass(preset_cls, cls):
179184
raise ValueError(

keras_nlp/src/models/backbone_test.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,7 @@ def test_from_preset(self):
5050
def test_from_preset_errors(self):
5151
with self.assertRaises(ValueError):
5252
GPT2Backbone.from_preset("bert_tiny_en_uncased", load_weights=False)
53-
with self.assertRaisesRegex(
54-
FileNotFoundError, f"doesn't have a file named `{METADATA_FILE}`"
55-
):
53+
with self.assertRaises(ValueError):
5654
# No loading on a non-keras model.
5755
Backbone.from_preset("hf://google-bert/bert-base-uncased")
5856

keras_nlp/src/models/preprocessor.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@
2222
from keras_nlp.src.utils.preset_utils import TOKENIZER_CONFIG_FILE
2323
from keras_nlp.src.utils.preset_utils import check_config_class
2424
from keras_nlp.src.utils.preset_utils import check_file_exists
25+
from keras_nlp.src.utils.preset_utils import check_format
2526
from keras_nlp.src.utils.preset_utils import list_presets
2627
from keras_nlp.src.utils.preset_utils import list_subclasses
2728
from keras_nlp.src.utils.preset_utils import load_serialized_object
2829
from keras_nlp.src.utils.preset_utils import save_serialized_object
29-
from keras_nlp.src.utils.preset_utils import validate_metadata
3030
from keras_nlp.src.utils.python_utils import classproperty
3131

3232

@@ -128,7 +128,14 @@ def from_preset(
128128
)
129129
```
130130
"""
131-
validate_metadata(preset)
131+
format = check_format(preset)
132+
133+
if format == "transformers":
134+
if cls.tokenizer_cls is None:
135+
raise ValueError("Tokenizer class is None")
136+
tokenizer = cls.tokenizer_cls.from_preset(preset)
137+
return cls(tokenizer=tokenizer, **kwargs)
138+
132139
if cls == Preprocessor:
133140
raise ValueError(
134141
"Do not call `Preprocessor.from_preset()` directly. Instead call a "

keras_nlp/src/models/preprocessor_test.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
RobertaPreprocessor,
2828
)
2929
from keras_nlp.src.tests.test_case import TestCase
30-
from keras_nlp.src.utils.preset_utils import METADATA_FILE
3130
from keras_nlp.src.utils.preset_utils import PREPROCESSOR_CONFIG_FILE
3231
from keras_nlp.src.utils.preset_utils import TOKENIZER_ASSET_DIR
3332
from keras_nlp.src.utils.preset_utils import check_config_class
@@ -67,9 +66,7 @@ def test_from_preset_errors(self):
6766
with self.assertRaises(ValueError):
6867
# No loading on an incorrect class.
6968
BertPreprocessor.from_preset("gpt2_base_en")
70-
with self.assertRaisesRegex(
71-
FileNotFoundError, f"doesn't have a file named `{METADATA_FILE}`"
72-
):
69+
with self.assertRaises(ValueError):
7370
# No loading on a non-keras model.
7471
Preprocessor.from_preset("hf://google-bert/bert-base-uncased")
7572

keras_nlp/src/models/task.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,13 @@
2828
from keras_nlp.src.utils.preset_utils import TASK_WEIGHTS_FILE
2929
from keras_nlp.src.utils.preset_utils import check_config_class
3030
from keras_nlp.src.utils.preset_utils import check_file_exists
31+
from keras_nlp.src.utils.preset_utils import check_format
3132
from keras_nlp.src.utils.preset_utils import get_file
3233
from keras_nlp.src.utils.preset_utils import jax_memory_cleanup
3334
from keras_nlp.src.utils.preset_utils import list_presets
3435
from keras_nlp.src.utils.preset_utils import list_subclasses
3536
from keras_nlp.src.utils.preset_utils import load_serialized_object
3637
from keras_nlp.src.utils.preset_utils import save_serialized_object
37-
from keras_nlp.src.utils.preset_utils import validate_metadata
3838
from keras_nlp.src.utils.python_utils import classproperty
3939

4040

@@ -187,7 +187,17 @@ def from_preset(
187187
)
188188
```
189189
"""
190-
validate_metadata(preset)
190+
format = check_format(preset)
191+
192+
if format == "transformers":
193+
if cls.backbone_cls is None:
194+
raise ValueError("Backbone class is None")
195+
if cls.preprocessor_cls is None:
196+
raise ValueError("Preprocessor class is None")
197+
198+
backbone = cls.backbone_cls.from_preset(preset)
199+
preprocessor = cls.preprocessor_cls.from_preset(preset)
200+
return cls(backbone=backbone, preprocessor=preprocessor, **kwargs)
191201

192202
if cls == Task:
193203
raise ValueError(

keras_nlp/src/models/task_test.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,7 @@ def test_from_preset_errors(self):
7979
with self.assertRaises(ValueError):
8080
# No loading on an incorrect class.
8181
BertClassifier.from_preset("gpt2_base_en", load_weights=False)
82-
with self.assertRaisesRegex(
83-
FileNotFoundError, f"doesn't have a file named `{METADATA_FILE}`"
84-
):
82+
with self.assertRaises(ValueError):
8583
# No loading on a non-keras model.
8684
CausalLM.from_preset("hf://google-bert/bert-base-uncased")
8785

keras_nlp/src/tokenizers/tokenizer.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,15 @@
2020
from keras_nlp.src.utils.preset_utils import TOKENIZER_ASSET_DIR
2121
from keras_nlp.src.utils.preset_utils import TOKENIZER_CONFIG_FILE
2222
from keras_nlp.src.utils.preset_utils import check_config_class
23+
from keras_nlp.src.utils.preset_utils import check_format
2324
from keras_nlp.src.utils.preset_utils import get_file
2425
from keras_nlp.src.utils.preset_utils import list_presets
2526
from keras_nlp.src.utils.preset_utils import list_subclasses
2627
from keras_nlp.src.utils.preset_utils import load_serialized_object
2728
from keras_nlp.src.utils.preset_utils import save_serialized_object
2829
from keras_nlp.src.utils.preset_utils import save_tokenizer_assets
29-
from keras_nlp.src.utils.preset_utils import validate_metadata
3030
from keras_nlp.src.utils.python_utils import classproperty
31+
from keras_nlp.src.utils.transformers.convert import load_transformers_tokenizer
3132

3233

3334
@keras_nlp_export(
@@ -215,7 +216,10 @@ class like `keras_nlp.models.Tokenizer.from_preset()`, or from
215216
tokenizer.detokenize([5, 6, 7, 8, 9])
216217
```
217218
"""
218-
validate_metadata(preset)
219+
format = check_format(preset)
220+
if format == "transformers":
221+
return load_transformers_tokenizer(cls, preset)
222+
219223
preset_cls = check_config_class(
220224
preset, config_file=TOKENIZER_CONFIG_FILE
221225
)

keras_nlp/src/tokenizers/tokenizer_test.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
from keras_nlp.src.models.roberta.roberta_tokenizer import RobertaTokenizer
3232
from keras_nlp.src.tests.test_case import TestCase
3333
from keras_nlp.src.tokenizers.tokenizer import Tokenizer
34-
from keras_nlp.src.utils.preset_utils import METADATA_FILE
3534
from keras_nlp.src.utils.preset_utils import TOKENIZER_ASSET_DIR
3635
from keras_nlp.src.utils.preset_utils import TOKENIZER_CONFIG_FILE
3736
from keras_nlp.src.utils.preset_utils import check_config_class
@@ -70,9 +69,7 @@ def test_from_preset(self):
7069
def test_from_preset_errors(self):
7170
with self.assertRaises(ValueError):
7271
GPT2Tokenizer.from_preset("bert_tiny_en_uncased")
73-
with self.assertRaisesRegex(
74-
FileNotFoundError, f"doesn't have a file named `{METADATA_FILE}`"
75-
):
72+
with self.assertRaises(ValueError):
7673
# No loading on a non-keras model.
7774
Tokenizer.from_preset("hf://google-bert/bert-base-uncased")
7875

keras_nlp/src/utils/preset_utils.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,19 @@
5959

6060
# Config file names.
6161
CONFIG_FILE = "config.json"
62+
HF_CONFIG_FILE = "config.json"
6263
TOKENIZER_CONFIG_FILE = "tokenizer.json"
6364
TASK_CONFIG_FILE = "task.json"
6465
PREPROCESSOR_CONFIG_FILE = "preprocessor.json"
6566
METADATA_FILE = "metadata.json"
67+
SAFETENSOR_CONFIG_FILE = "model.safetensors.index.json"
6668

6769
README_FILE = "README.md"
6870

6971
# Weight file names.
7072
MODEL_WEIGHTS_FILE = "model.weights.h5"
7173
TASK_WEIGHTS_FILE = "task.weights.h5"
74+
SAFETENSOR_FILE = "model.safetensors"
7275

7376
# Global state for preset registry.
7477
BUILTIN_PRESETS = {}
@@ -324,7 +327,7 @@ def _validate_tokenizer(preset, allow_incomplete=False):
324327
)
325328
config_path = get_file(preset, TOKENIZER_CONFIG_FILE)
326329
try:
327-
with open(config_path) as config_file:
330+
with open(config_path, encoding="utf-8") as config_file:
328331
config = json.load(config_file)
329332
except Exception as e:
330333
raise ValueError(
@@ -357,7 +360,7 @@ def _validate_backbone(preset):
357360
f"`{CONFIG_FILE}` is missing from the preset directory `{preset}`."
358361
)
359362
try:
360-
with open(config_path) as config_file:
363+
with open(config_path, encoding="utf-8") as config_file:
361364
json.load(config_file)
362365
except Exception as e:
363366
raise ValueError(
@@ -530,12 +533,17 @@ def upload_preset(
530533

531534
def load_config(preset, config_file=CONFIG_FILE):
532535
config_path = get_file(preset, config_file)
533-
with open(config_path) as config_file:
536+
with open(config_path, encoding="utf-8") as config_file:
534537
config = json.load(config_file)
535538
return config
536539

537540

538-
def validate_metadata(preset):
541+
def check_format(preset):
542+
if check_file_exists(preset, SAFETENSOR_FILE) or check_file_exists(
543+
preset, SAFETENSOR_CONFIG_FILE
544+
):
545+
return "transformers"
546+
539547
if not check_file_exists(preset, METADATA_FILE):
540548
raise FileNotFoundError(
541549
f"The preset directory `{preset}` doesn't have a file named `{METADATA_FILE}`, "
@@ -548,6 +556,7 @@ def validate_metadata(preset):
548556
f"`{METADATA_FILE}` in the preset directory `{preset}` doesn't have `keras_version`. "
549557
"Please verify that the model you are trying to load is a Keras model."
550558
)
559+
return "keras"
551560

552561

553562
def load_serialized_object(
@@ -566,7 +575,7 @@ def check_config_class(
566575
):
567576
"""Validate a preset is being loaded on the correct class."""
568577
config_path = get_file(preset, config_file)
569-
with open(config_path) as config_file:
578+
with open(config_path, encoding="utf-8") as config_file:
570579
config = json.load(config_file)
571580
return keras.saving.get_registered_object(config["registered_name"])
572581

keras_nlp/src/utils/preset_utils_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from keras_nlp.src.utils.preset_utils import CONFIG_FILE
2727
from keras_nlp.src.utils.preset_utils import METADATA_FILE
2828
from keras_nlp.src.utils.preset_utils import TOKENIZER_CONFIG_FILE
29-
from keras_nlp.src.utils.preset_utils import validate_metadata
29+
from keras_nlp.src.utils.preset_utils import check_format
3030

3131

3232
class PresetUtilsTest(TestCase):
@@ -100,7 +100,7 @@ def test_missing_metadata(self):
100100
with self.assertRaisesRegex(
101101
FileNotFoundError, f"doesn't have a file named `{METADATA_FILE}`"
102102
):
103-
validate_metadata(preset_dir)
103+
check_format(preset_dir)
104104

105105
def test_incorrect_metadata(self):
106106
temp_dir = self.get_temp_dir()
@@ -112,4 +112,4 @@ def test_incorrect_metadata(self):
112112
json.dump(data, f)
113113

114114
with self.assertRaisesRegex(ValueError, "doesn't have `keras_version`"):
115-
validate_metadata(preset_dir)
115+
check_format(preset_dir)

0 commit comments

Comments
 (0)