5959
6060# Config file names.
6161CONFIG_FILE = "config.json"
62+ HF_CONFIG_FILE = "config.json"
6263TOKENIZER_CONFIG_FILE = "tokenizer.json"
6364TASK_CONFIG_FILE = "task.json"
6465PREPROCESSOR_CONFIG_FILE = "preprocessor.json"
6566METADATA_FILE = "metadata.json"
67+ SAFETENSOR_CONFIG_FILE = "model.safetensors.index.json"
6668
6769README_FILE = "README.md"
6870
6971# Weight file names.
7072MODEL_WEIGHTS_FILE = "model.weights.h5"
7173TASK_WEIGHTS_FILE = "task.weights.h5"
74+ SAFETENSOR_FILE = "model.safetensors"
7275
7376# Global state for preset registry.
7477BUILTIN_PRESETS = {}
@@ -324,7 +327,7 @@ def _validate_tokenizer(preset, allow_incomplete=False):
324327 )
325328 config_path = get_file (preset , TOKENIZER_CONFIG_FILE )
326329 try :
327- with open (config_path ) as config_file :
330+ with open (config_path , encoding = "utf-8" ) as config_file :
328331 config = json .load (config_file )
329332 except Exception as e :
330333 raise ValueError (
@@ -357,7 +360,7 @@ def _validate_backbone(preset):
357360 f"`{ CONFIG_FILE } ` is missing from the preset directory `{ preset } `."
358361 )
359362 try :
360- with open (config_path ) as config_file :
363+ with open (config_path , encoding = "utf-8" ) as config_file :
361364 json .load (config_file )
362365 except Exception as e :
363366 raise ValueError (
@@ -530,12 +533,17 @@ def upload_preset(
530533
531534def load_config (preset , config_file = CONFIG_FILE ):
532535 config_path = get_file (preset , config_file )
533- with open (config_path ) as config_file :
536+ with open (config_path , encoding = "utf-8" ) as config_file :
534537 config = json .load (config_file )
535538 return config
536539
537540
538- def validate_metadata (preset ):
541+ def check_format (preset ):
542+ if check_file_exists (preset , SAFETENSOR_FILE ) or check_file_exists (
543+ preset , SAFETENSOR_CONFIG_FILE
544+ ):
545+ return "transformers"
546+
539547 if not check_file_exists (preset , METADATA_FILE ):
540548 raise FileNotFoundError (
541549 f"The preset directory `{ preset } ` doesn't have a file named `{ METADATA_FILE } `, "
@@ -548,6 +556,7 @@ def validate_metadata(preset):
548556 f"`{ METADATA_FILE } ` in the preset directory `{ preset } ` doesn't have `keras_version`. "
549557 "Please verify that the model you are trying to load is a Keras model."
550558 )
559+ return "keras"
551560
552561
553562def load_serialized_object (
@@ -566,7 +575,7 @@ def check_config_class(
566575):
567576 """Validate a preset is being loaded on the correct class."""
568577 config_path = get_file (preset , config_file )
569- with open (config_path ) as config_file :
578+ with open (config_path , encoding = "utf-8" ) as config_file :
570579 config = json .load (config_file )
571580 return keras .saving .get_registered_object (config ["registered_name" ])
572581
0 commit comments