diff --git a/keras_hub/api/__init__.py b/keras_hub/api/__init__.py index 2aa98bf3f9..810f8fa921 100644 --- a/keras_hub/api/__init__.py +++ b/keras_hub/api/__init__.py @@ -4,6 +4,7 @@ since your modifications would be overwritten. """ +from keras_hub import export as export from keras_hub import layers as layers from keras_hub import metrics as metrics from keras_hub import models as models diff --git a/keras_hub/api/export/__init__.py b/keras_hub/api/export/__init__.py new file mode 100644 index 0000000000..25d1cc446a --- /dev/null +++ b/keras_hub/api/export/__init__.py @@ -0,0 +1,28 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras_hub.src.export.configs import ( + CausalLMExporterConfig as CausalLMExporterConfig, +) +from keras_hub.src.export.configs import ( + ImageClassifierExporterConfig as ImageClassifierExporterConfig, +) +from keras_hub.src.export.configs import ( + ImageSegmenterExporterConfig as ImageSegmenterExporterConfig, +) +from keras_hub.src.export.configs import ( + ObjectDetectorExporterConfig as ObjectDetectorExporterConfig, +) +from keras_hub.src.export.configs import ( + Seq2SeqLMExporterConfig as Seq2SeqLMExporterConfig, +) +from keras_hub.src.export.configs import ( + TextClassifierExporterConfig as TextClassifierExporterConfig, +) +from keras_hub.src.export.configs import ( + TextModelExporterConfig as TextModelExporterConfig, +) +from keras_hub.src.export.litert import LitertExporter as LitertExporter diff --git a/keras_hub/src/export/__init__.py b/keras_hub/src/export/__init__.py new file mode 100644 index 0000000000..4c32e4411d --- /dev/null +++ b/keras_hub/src/export/__init__.py @@ -0,0 +1,9 @@ +from keras_hub.src.export.base import ExporterRegistry +from keras_hub.src.export.base import KerasHubExporter +from keras_hub.src.export.base import KerasHubExporterConfig +from keras_hub.src.export.configs import CausalLMExporterConfig +from 
keras_hub.src.export.configs import Seq2SeqLMExporterConfig +from keras_hub.src.export.configs import TextClassifierExporterConfig +from keras_hub.src.export.configs import TextModelExporterConfig +from keras_hub.src.export.litert import LitertExporter +from keras_hub.src.export.litert import export_litert diff --git a/keras_hub/src/export/base.py b/keras_hub/src/export/base.py new file mode 100644 index 0000000000..9e4a1cf8e5 --- /dev/null +++ b/keras_hub/src/export/base.py @@ -0,0 +1,310 @@ +"""Base classes for Keras-Hub model exporters. + +This module provides the foundation for exporting Keras-Hub models to various +formats. It follows the Optimum pattern of having different exporters for +different model types and formats. +""" + +from abc import ABC +from abc import abstractmethod + +try: + import keras + + KERAS_AVAILABLE = True +except ImportError: + KERAS_AVAILABLE = False + keras = None + + +class KerasHubExporterConfig(ABC): + """Base configuration class for Keras-Hub model exporters. + + This class defines the interface for exporter configurations that specify + how different types of Keras-Hub models should be exported. + """ + + # Model type this exporter handles (e.g., "causal_lm", "text_classifier") + MODEL_TYPE = None + + # Expected input structure for this model type + EXPECTED_INPUTS = [] + + # Default sequence length if not specified + DEFAULT_SEQUENCE_LENGTH = 128 + + def __init__(self, model, **kwargs): + """Initialize the exporter configuration. + + Args: + model: `keras.Model`. The Keras-Hub model to export. + **kwargs: Additional configuration parameters. 
+ """ + self.model = model + self.config_kwargs = kwargs + self._validate_model() + + def _validate_model(self): + """Validate that the model is compatible with this exporter.""" + if not self._is_model_compatible(): + raise ValueError( + f"Model {self.model.__class__.__name__} is not compatible " + f"with {self.__class__.__name__}" + ) + + @abstractmethod + def _is_model_compatible(self): + """Check if the model is compatible with this exporter. + + Returns: + bool: True if compatible, False otherwise + """ + pass + + @abstractmethod + def get_input_signature(self, sequence_length=None): + """Get the input signature for this model type. + + Args: + sequence_length: `int` or `None`. Optional sequence length for + input tensors. + + Returns: + A dictionary mapping input names to their tensor specifications. + """ + pass + + def get_dummy_inputs(self, sequence_length=None): + """Generate dummy inputs for model building and testing. + + Args: + sequence_length: `int` or `None`. Optional sequence length for + dummy inputs. + + Returns: + A dictionary of dummy inputs. 
+ """ + if sequence_length is None: + sequence_length = self.DEFAULT_SEQUENCE_LENGTH + + dummy_inputs = {} + + # Common inputs for most Keras-Hub models + if "token_ids" in self.EXPECTED_INPUTS: + dummy_inputs["token_ids"] = keras.ops.ones( + (1, sequence_length), dtype="int32" + ) + if "padding_mask" in self.EXPECTED_INPUTS: + dummy_inputs["padding_mask"] = keras.ops.ones( + (1, sequence_length), dtype="bool" + ) + + # Encoder-decoder specific inputs + if "encoder_token_ids" in self.EXPECTED_INPUTS: + dummy_inputs["encoder_token_ids"] = keras.ops.ones( + (1, sequence_length), dtype="int32" + ) + if "encoder_padding_mask" in self.EXPECTED_INPUTS: + dummy_inputs["encoder_padding_mask"] = keras.ops.ones( + (1, sequence_length), dtype="bool" + ) + if "decoder_token_ids" in self.EXPECTED_INPUTS: + dummy_inputs["decoder_token_ids"] = keras.ops.ones( + (1, sequence_length), dtype="int32" + ) + if "decoder_padding_mask" in self.EXPECTED_INPUTS: + dummy_inputs["decoder_padding_mask"] = keras.ops.ones( + (1, sequence_length), dtype="bool" + ) + + return dummy_inputs + + +class KerasHubExporter(ABC): + """Base class for Keras-Hub model exporters. + + This class provides the common interface for exporting Keras-Hub models + to different formats (LiteRT, ONNX, etc.). + """ + + def __init__(self, config, **kwargs): + """Initialize the exporter. + + Args: + config: `KerasHubExporterConfig`. Exporter configuration specifying + model type and parameters. + **kwargs: Additional exporter-specific parameters. + """ + self.config = config + self.model = config.model + self.export_kwargs = kwargs + + @abstractmethod + def export(self, filepath): + """Export the model to the specified filepath. + + Args: + filepath: `str`. Path where to save the exported model. + """ + pass + + def _ensure_model_built(self, param=None): + """Ensure the model is properly built with correct input structure. + + This method builds the model using model.build() with input shapes. 
+ This creates the necessary variables and initializes the model structure + for export, avoiding the need for dummy forward passes. + + Note: We don't check model.built because it can be True even if the + model isn't properly initialized with the correct input structure. + + Args: + param: `int` or `None`. Optional parameter for input signature + (e.g., sequence_length for text models, image_size for vision + models). + """ + # Get input signature (returns dict of InputSpec objects) + input_signature = self.config.get_input_signature(param) + + # Extract shapes from InputSpec objects + input_shapes = {} + for name, spec in input_signature.items(): + if hasattr(spec, "shape"): + input_shapes[name] = spec.shape + else: + # Fallback for unexpected formats + input_shapes[name] = spec + + try: + # Build the model using shapes only (no actual data allocation) + # This creates variables and initializes the model structure + self.model.build(input_shape=input_shapes) + except Exception as e: + # Fallback to forward pass approach if build() fails + # This maintains backward compatibility for models that don't + # support shape-based building + try: + dummy_inputs = self.config.get_dummy_inputs(param) + _ = self.model(dummy_inputs, training=False) + except Exception as fallback_error: + raise ValueError( + f"Failed to build model with both shape-based building " + f"({e}) and forward pass ({fallback_error}). Please ensure " + f"the model is properly constructed." + ) from fallback_error + + +class ExporterRegistry: + """Registry for mapping model types to their appropriate exporters.""" + + _configs = {} + _exporters = {} + + @classmethod + def register_config(cls, model_type, config_class): + """Register a configuration class for a model type. 
+ + Args: + model_type: The model type (e.g., "causal_lm") + config_class: The configuration class + """ + cls._configs[model_type] = config_class + + @classmethod + def register_exporter(cls, format_name, exporter_class): + """Register an exporter class for a format. + + Args: + format_name: The export format (e.g., "litert") + exporter_class: The exporter class + """ + cls._exporters[format_name] = exporter_class + + @classmethod + def get_config_for_model(cls, model): + """Get the appropriate configuration for a model. + + Args: + model: The Keras-Hub model + + Returns: + KerasHubExporterConfig: An appropriate exporter configuration + instance + + Raises: + ValueError: If no configuration is found for the model type + """ + model_type = cls._detect_model_type(model) + + if model_type not in cls._configs: + raise ValueError( + f"No configuration found for model type: {model_type}" + ) + + config_class = cls._configs[model_type] + return config_class(model) + + @classmethod + def get_exporter(cls, format_name, config, **kwargs): + """Get an exporter for the specified format. + + Args: + format_name: The export format + config: The exporter configuration + **kwargs: Additional parameters for the exporter + + Returns: + KerasHubExporter: An appropriate exporter instance + + Raises: + ValueError: If no exporter is found for the format + """ + if format_name not in cls._exporters: + raise ValueError(f"No exporter found for format: {format_name}") + + exporter_class = cls._exporters[format_name] + return exporter_class(config, **kwargs) + + @classmethod + def _detect_model_type(cls, model): + """Detect the model type from the model instance. 
+ + Args: + model: The Keras-Hub model + + Returns: + str: The detected model type + """ + # Import here to avoid circular imports + try: + from keras_hub.src.models.causal_lm import CausalLM + from keras_hub.src.models.image_segmenter import ImageSegmenter + from keras_hub.src.models.object_detector import ObjectDetector + from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM + except ImportError: + CausalLM = None + Seq2SeqLM = None + ObjectDetector = None + ImageSegmenter = None + + model_class_name = model.__class__.__name__ + + if CausalLM and isinstance(model, CausalLM): + return "causal_lm" + elif "TextClassifier" in model_class_name: + return "text_classifier" + elif Seq2SeqLM and isinstance(model, Seq2SeqLM): + return "seq2seq_lm" + elif "ImageClassifier" in model_class_name: + return "image_classifier" + elif ObjectDetector and isinstance(model, ObjectDetector): + return "object_detector" + elif "ObjectDetector" in model_class_name: + return "object_detector" + elif ImageSegmenter and isinstance(model, ImageSegmenter): + return "image_segmenter" + elif "ImageSegmenter" in model_class_name: + return "image_segmenter" + else: + # Default to text model for generic Keras-Hub models + return "text_model" diff --git a/keras_hub/src/export/configs.py b/keras_hub/src/export/configs.py new file mode 100644 index 0000000000..97f5a43721 --- /dev/null +++ b/keras_hub/src/export/configs.py @@ -0,0 +1,532 @@ +"""Configuration classes for different Keras-Hub model types. + +This module provides specific configurations for exporting different types +of Keras-Hub models, following the Optimum pattern. 
+""" + +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.export.base import KerasHubExporterConfig + + +@keras_hub_export("keras_hub.export.CausalLMExporterConfig") +class CausalLMExporterConfig(KerasHubExporterConfig): + """Exporter configuration for Causal Language Models (GPT, LLaMA, etc.).""" + + MODEL_TYPE = "causal_lm" + EXPECTED_INPUTS = ["token_ids", "padding_mask"] + DEFAULT_SEQUENCE_LENGTH = 128 + + def _is_model_compatible(self): + """Check if model is a causal language model. + + Returns: + bool: True if compatible, False otherwise + """ + try: + from keras_hub.src.models.causal_lm import CausalLM + + return isinstance(self.model, CausalLM) + except ImportError: + # Fallback to class name checking + return "CausalLM" in self.model.__class__.__name__ + + def get_input_signature(self, sequence_length=None): + """Get input signature for causal LM models. + + Args: + sequence_length: Optional sequence length. If None, uses default. + + Returns: + Dict[str, Any]: Dictionary mapping input names to their + specifications + """ + if sequence_length is None: + # Get from preprocessor or use default + if hasattr(self.model, "preprocessor") and self.model.preprocessor: + sequence_length = getattr( + self.model.preprocessor, + "sequence_length", + self.DEFAULT_SEQUENCE_LENGTH, + ) + else: + sequence_length = self.DEFAULT_SEQUENCE_LENGTH + + return { + "token_ids": keras.layers.InputSpec( + shape=(None, sequence_length), dtype="int32", name="token_ids" + ), + "padding_mask": keras.layers.InputSpec( + shape=(None, sequence_length), dtype="bool", name="padding_mask" + ), + } + + +@keras_hub_export("keras_hub.export.TextClassifierExporterConfig") +class TextClassifierExporterConfig(KerasHubExporterConfig): + """Exporter configuration for Text Classification models.""" + + MODEL_TYPE = "text_classifier" + EXPECTED_INPUTS = ["token_ids", "padding_mask"] + DEFAULT_SEQUENCE_LENGTH = 128 + + def _is_model_compatible(self): + """Check 
if model is a text classifier. + + Returns: + bool: True if compatible, False otherwise + """ + return "TextClassifier" in self.model.__class__.__name__ + + def get_input_signature(self, sequence_length=None): + """Get input signature for text classifier models. + + Args: + sequence_length: Optional sequence length. If None, uses default. + + Returns: + Dict[str, Any]: Dictionary mapping input names to their + specifications + """ + if sequence_length is None: + # Get from preprocessor or use default + if hasattr(self.model, "preprocessor") and self.model.preprocessor: + sequence_length = getattr( + self.model.preprocessor, + "sequence_length", + self.DEFAULT_SEQUENCE_LENGTH, + ) + else: + sequence_length = self.DEFAULT_SEQUENCE_LENGTH + + return { + "token_ids": keras.layers.InputSpec( + shape=(None, sequence_length), dtype="int32", name="token_ids" + ), + "padding_mask": keras.layers.InputSpec( + shape=(None, sequence_length), dtype="bool", name="padding_mask" + ), + } + + +@keras_hub_export("keras_hub.export.Seq2SeqLMExporterConfig") +class Seq2SeqLMExporterConfig(KerasHubExporterConfig): + """Exporter configuration for Sequence-to-Sequence Language Models.""" + + MODEL_TYPE = "seq2seq_lm" + EXPECTED_INPUTS = [ + "encoder_token_ids", + "encoder_padding_mask", + "decoder_token_ids", + "decoder_padding_mask", + ] + DEFAULT_SEQUENCE_LENGTH = 128 + + def _is_model_compatible(self): + """Check if model is a seq2seq language model. + + Returns: + bool: True if compatible, False otherwise + """ + try: + from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM + + return isinstance(self.model, Seq2SeqLM) + except ImportError: + return "Seq2SeqLM" in self.model.__class__.__name__ + + def get_input_signature(self, sequence_length=None): + """Get input signature for seq2seq models. + + Args: + sequence_length: Optional sequence length. If None, uses default. 
+ + Returns: + Dict[str, Any]: Dictionary mapping input names to their + specifications + """ + if sequence_length is None: + # Get from preprocessor or use default + if hasattr(self.model, "preprocessor") and self.model.preprocessor: + sequence_length = getattr( + self.model.preprocessor, + "sequence_length", + self.DEFAULT_SEQUENCE_LENGTH, + ) + else: + sequence_length = self.DEFAULT_SEQUENCE_LENGTH + + return { + "encoder_token_ids": keras.layers.InputSpec( + shape=(None, sequence_length), + dtype="int32", + name="encoder_token_ids", + ), + "encoder_padding_mask": keras.layers.InputSpec( + shape=(None, sequence_length), + dtype="bool", + name="encoder_padding_mask", + ), + "decoder_token_ids": keras.layers.InputSpec( + shape=(None, sequence_length), + dtype="int32", + name="decoder_token_ids", + ), + "decoder_padding_mask": keras.layers.InputSpec( + shape=(None, sequence_length), + dtype="bool", + name="decoder_padding_mask", + ), + } + + +@keras_hub_export("keras_hub.export.TextModelExporterConfig") +class TextModelExporterConfig(KerasHubExporterConfig): + """Generic exporter configuration for text models.""" + + MODEL_TYPE = "text_model" + EXPECTED_INPUTS = ["token_ids", "padding_mask"] + DEFAULT_SEQUENCE_LENGTH = 128 + + def _is_model_compatible(self): + """Check if model is a text model (fallback). + + Returns: + bool: True if compatible, False otherwise + """ + # This is a fallback config for text models that don't fit other + # categories + return ( + hasattr(self.model, "preprocessor") + and self.model.preprocessor + and hasattr(self.model.preprocessor, "tokenizer") + ) + + def get_input_signature(self, sequence_length=None): + """Get input signature for generic text models. + + Args: + sequence_length: Optional sequence length. If None, uses default. 
+ + Returns: + Dict[str, Any]: Dictionary mapping input names to their + specifications + """ + if sequence_length is None: + # Get from preprocessor or use default + if hasattr(self.model, "preprocessor") and self.model.preprocessor: + sequence_length = getattr( + self.model.preprocessor, + "sequence_length", + self.DEFAULT_SEQUENCE_LENGTH, + ) + else: + sequence_length = self.DEFAULT_SEQUENCE_LENGTH + + return { + "token_ids": keras.layers.InputSpec( + shape=(None, sequence_length), dtype="int32", name="token_ids" + ), + "padding_mask": keras.layers.InputSpec( + shape=(None, sequence_length), dtype="bool", name="padding_mask" + ), + } + + +@keras_hub_export("keras_hub.export.ImageClassifierExporterConfig") +class ImageClassifierExporterConfig(KerasHubExporterConfig): + """Exporter configuration for Image Classification models.""" + + MODEL_TYPE = "image_classifier" + EXPECTED_INPUTS = ["images"] + + def _is_model_compatible(self): + """Check if model is an image classifier. + Returns: + bool: True if compatible, False otherwise + """ + return "ImageClassifier" in self.model.__class__.__name__ + + def get_input_signature(self, image_size=None): + """Get input signature for image classifier models. + Args: + image_size: Optional image size. If None, inferred from model. 
+ Returns: + Dict[str, Any]: Dictionary mapping input names to their + specifications + """ + if image_size is None: + # Get from preprocessor + if hasattr(self.model, "preprocessor") and self.model.preprocessor: + if hasattr(self.model.preprocessor, "image_size"): + image_size = self.model.preprocessor.image_size + + # Try to infer from model inputs + if ( + image_size is None + and hasattr(self.model, "inputs") + and self.model.inputs + ): + input_shape = self.model.inputs[0].shape + if ( + len(input_shape) == 4 + and input_shape[1] is not None + and input_shape[2] is not None + ): + image_size = (input_shape[1], input_shape[2]) + + if image_size is None: + raise ValueError( + "Could not determine image size from model. " + "Model should have a preprocessor with image_size " + "attribute, or model inputs should have concrete shapes." + ) + + if isinstance(image_size, int): + image_size = (image_size, image_size) + + # Get input dtype + dtype = "float32" + if hasattr(self.model, "inputs") and self.model.inputs: + model_dtype = self.model.inputs[0].dtype + dtype = ( + model_dtype.name + if hasattr(model_dtype, "name") + else model_dtype + ) + + return { + "images": keras.layers.InputSpec( + shape=(None, *image_size, 3), + dtype=dtype, + name="images", + ), + } + + def get_dummy_inputs(self, image_size=None): + """Generate dummy inputs for image classifier models. + + Args: + image_size: Optional image size. If None, inferred from model. 
+ + Returns: + Dict[str, Any]: Dictionary of dummy inputs + """ + if image_size is None: + # Get image size using same logic as get_input_signature + if hasattr(self.model, "preprocessor") and self.model.preprocessor: + if hasattr(self.model.preprocessor, "image_size"): + image_size = self.model.preprocessor.image_size + if ( + image_size is None + and hasattr(self.model, "inputs") + and self.model.inputs + ): + input_shape = self.model.inputs[0].shape + if ( + len(input_shape) == 4 + and input_shape[1] is not None + and input_shape[2] is not None + ): + image_size = (input_shape[1], input_shape[2]) + + if isinstance(image_size, int): + image_size = (image_size, image_size) + + dummy_inputs = {} + if "images" in self.EXPECTED_INPUTS: + dummy_inputs["images"] = keras.ops.ones( + (1, *image_size, 3), dtype="float32" + ) + + return dummy_inputs + + +@keras_hub_export("keras_hub.export.ObjectDetectorExporterConfig") +class ObjectDetectorExporterConfig(KerasHubExporterConfig): + """Exporter configuration for Object Detection models.""" + + MODEL_TYPE = "object_detector" + EXPECTED_INPUTS = ["images", "image_shape"] + + def _is_model_compatible(self): + """Check if model is an object detector. + Returns: + bool: True if compatible, False otherwise + """ + return "ObjectDetector" in self.model.__class__.__name__ + + def get_input_signature(self, image_size=None): + """Get input signature for object detector models. + Args: + image_size: Optional image size. If None, inferred from model. 
+ Returns: + Dict[str, Any]: Dictionary mapping input names to their + specifications + """ + if image_size is None: + # Get from preprocessor + if hasattr(self.model, "preprocessor") and self.model.preprocessor: + if hasattr(self.model.preprocessor, "image_size"): + image_size = self.model.preprocessor.image_size + + # Try to infer from model inputs + if ( + image_size is None + and hasattr(self.model, "inputs") + and self.model.inputs + ): + input_shape = self.model.inputs[0].shape + if ( + len(input_shape) == 4 + and input_shape[1] is not None + and input_shape[2] is not None + ): + image_size = (input_shape[1], input_shape[2]) + + if image_size is None: + raise ValueError( + "Could not determine image size from model. " + "Model should have a preprocessor with image_size " + "attribute, or model inputs should have concrete shapes." + ) + + if isinstance(image_size, int): + image_size = (image_size, image_size) + + # Get input dtype + dtype = "float32" + if hasattr(self.model, "inputs") and self.model.inputs: + model_dtype = self.model.inputs[0].dtype + dtype = ( + model_dtype.name + if hasattr(model_dtype, "name") + else model_dtype + ) + + return { + "images": keras.layers.InputSpec( + shape=(None, *image_size, 3), + dtype=dtype, + name="images", + ), + "image_shape": keras.layers.InputSpec( + shape=(None, 2), dtype="int32", name="image_shape" + ), + } + + def get_dummy_inputs(self, image_size=None): + """Generate dummy inputs for object detector models. + + Args: + image_size: Optional image size. If None, inferred from model. 
+ + Returns: + Dict[str, Any]: Dictionary of dummy inputs + """ + if image_size is None: + # Get image size using same logic as get_input_signature + if hasattr(self.model, "preprocessor") and self.model.preprocessor: + if hasattr(self.model.preprocessor, "image_size"): + image_size = self.model.preprocessor.image_size + if ( + image_size is None + and hasattr(self.model, "inputs") + and self.model.inputs + ): + input_shape = self.model.inputs[0].shape + if ( + len(input_shape) == 4 + and input_shape[1] is not None + and input_shape[2] is not None + ): + image_size = (input_shape[1], input_shape[2]) + + if isinstance(image_size, int): + image_size = (image_size, image_size) + + dummy_inputs = {} + + # Create dummy image input + dummy_inputs["images"] = keras.random.uniform( + (1, *image_size, 3), dtype="float32" + ) + + # Create dummy image shape input + dummy_inputs["image_shape"] = keras.ops.convert_to_tensor( + [[image_size[0], image_size[1]]], dtype="int32" + ) + + return dummy_inputs + + +@keras_hub_export("keras_hub.export.ImageSegmenterExporterConfig") +class ImageSegmenterExporterConfig(KerasHubExporterConfig): + """Exporter configuration for Image Segmentation models.""" + + MODEL_TYPE = "image_segmenter" + EXPECTED_INPUTS = ["images"] + + def _is_model_compatible(self): + """Check if model is an image segmenter. + Returns: + bool: True if compatible, False otherwise + """ + return "ImageSegmenter" in self.model.__class__.__name__ + + def get_input_signature(self, image_size=None): + """Get input signature for image segmenter models. + Args: + image_size: Optional image size. If None, inferred from model. 
+ Returns: + Dict[str, Any]: Dictionary mapping input names to their + specifications + """ + if image_size is None: + # Get from preprocessor + if hasattr(self.model, "preprocessor") and self.model.preprocessor: + if hasattr(self.model.preprocessor, "image_size"): + image_size = self.model.preprocessor.image_size + + # Try to infer from model inputs + if ( + image_size is None + and hasattr(self.model, "inputs") + and self.model.inputs + ): + input_shape = self.model.inputs[0].shape + if ( + len(input_shape) == 4 + and input_shape[1] is not None + and input_shape[2] is not None + ): + image_size = (input_shape[1], input_shape[2]) + + if image_size is None: + raise ValueError( + "Could not determine image size from model. " + "Model should have a preprocessor with image_size " + "attribute, or model inputs should have concrete shapes." + ) + + if isinstance(image_size, int): + image_size = (image_size, image_size) + + # Get input dtype + dtype = "float32" + if hasattr(self.model, "inputs") and self.model.inputs: + model_dtype = self.model.inputs[0].dtype + dtype = ( + model_dtype.name + if hasattr(model_dtype, "name") + else model_dtype + ) + + return { + "images": keras.layers.InputSpec( + shape=(None, *image_size, 3), + dtype=dtype, + name="images", + ), + } diff --git a/keras_hub/src/export/litert.py b/keras_hub/src/export/litert.py new file mode 100644 index 0000000000..8063674faf --- /dev/null +++ b/keras_hub/src/export/litert.py @@ -0,0 +1,308 @@ +"""LiteRT exporter for Keras-Hub models. + +This module provides LiteRT export functionality specifically designed for +Keras-Hub models, handling their unique input structures and requirements. 
+""" + +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.export.base import KerasHubExporter + +try: + from keras.src.export.litert import LitertExporter as KerasLitertExporter + + KERAS_LITE_RT_AVAILABLE = True +except ImportError: + KERAS_LITE_RT_AVAILABLE = False + KerasLitertExporter = None + + +@keras_hub_export("keras_hub.export.LitertExporter") +class LitertExporter(KerasHubExporter): + """LiteRT exporter for Keras-Hub models. + + This exporter handles the conversion of Keras-Hub models to TensorFlow Lite + format, properly managing the dictionary input structures that Keras-Hub + models expect. + """ + + def __init__( + self, + config, + max_sequence_length=None, + aot_compile_targets=None, + verbose=None, + **kwargs, + ): + """Initialize the LiteRT exporter. + + Args: + config: Exporter configuration for the model + max_sequence_length: Maximum sequence length for text models + aot_compile_targets: List of AOT compilation targets + verbose: Whether to print progress messages. Defaults to `None`, + which will use `True`. + **kwargs: Additional arguments passed to the underlying exporter + """ + super().__init__(config, **kwargs) + + if not KERAS_LITE_RT_AVAILABLE: + raise ImportError( + "Keras LiteRT exporter is not available. " + "Make sure you have Keras with LiteRT support installed." + ) + + self.max_sequence_length = max_sequence_length + self.aot_compile_targets = aot_compile_targets + self.verbose = verbose if verbose is not None else True + + def export(self, filepath): + """Export the Keras-Hub model to LiteRT format. 
+ + Args: + filepath: Path where to save the exported model (without extension) + """ + from keras.src.utils import io_utils + + if self.verbose: + io_utils.print_msg( + f"Starting LiteRT export for {self.config.MODEL_TYPE} model" + ) + + # For text models, use sequence_length; for other models, use None + is_text_model = self.config.MODEL_TYPE in [ + "causal_lm", + "text_classifier", + "seq2seq_lm", + ] + param = self.max_sequence_length if is_text_model else None + + # Ensure model is built + self._ensure_model_built(param) + + # Get input signature + input_signature = self.config.get_input_signature(param) + + # Create a wrapper that adapts the Keras-Hub model to work with Keras + # LiteRT exporter + wrapped_model = self._create_export_wrapper() + + # Convert input signature to list format expected by Keras exporter + if isinstance(input_signature, dict): + # Extract specs in the order expected by the model + signature_list = [] + for input_name in self.config.EXPECTED_INPUTS: + if input_name in input_signature: + signature_list.append(input_signature[input_name]) + input_signature = signature_list + + # Create the Keras LiteRT exporter with the wrapped model + keras_exporter = KerasLitertExporter( + wrapped_model, + input_signature=input_signature, + aot_compile_targets=self.aot_compile_targets, + verbose=self.verbose, + **self.export_kwargs, + ) + + try: + # Export using the Keras exporter + keras_exporter.export(filepath) + + if self.verbose: + io_utils.print_msg( + f"Export completed successfully to: {filepath}.tflite" + ) + + except Exception as e: + raise RuntimeError(f"LiteRT export failed: {e}") from e + + def _create_export_wrapper(self): + """Create a wrapper model that handles the input structure conversion. + + This wrapper converts between the list-based inputs that Keras LiteRT + exporter provides and the dictionary-based inputs that Keras-Hub models + expect. 
+ """ + + class KerasHubModelWrapper(keras.Model): + """Wrapper that adapts Keras-Hub models for export.""" + + def __init__( + self, keras_hub_model, expected_inputs, input_signature + ): + super().__init__() + self.keras_hub_model = keras_hub_model + self.expected_inputs = expected_inputs + self.input_signature = input_signature + + # Create Input layers based on the input signature + self._input_layers = [] + for input_name in expected_inputs: + if input_name in input_signature: + spec = input_signature[input_name] + # Ensure we preserve the correct dtype + input_layer = keras.layers.Input( + shape=spec.shape[1:], # Remove batch dimension + dtype=spec.dtype, + name=input_name, + ) + self._input_layers.append(input_layer) + + # Store references to the original model's variables + self._variables = keras_hub_model.variables + self._trainable_variables = keras_hub_model.trainable_variables + self._non_trainable_variables = ( + keras_hub_model.non_trainable_variables + ) + + @property + def variables(self): + return self._variables + + @property + def trainable_variables(self): + return self._trainable_variables + + @property + def non_trainable_variables(self): + return self._non_trainable_variables + + @property + def inputs(self): + """Return the input layers for the Keras exporter to use.""" + return self._input_layers + + def call(self, inputs, training=None, mask=None): + """Convert list inputs to dictionary format and call the + original model.""" + if isinstance(inputs, dict): + # Already in dictionary format + return self.keras_hub_model( + inputs, training=training, mask=mask + ) + + # Convert list inputs to dictionary format + if not isinstance(inputs, (list, tuple)): + inputs = [inputs] + + # For image classifiers, try the direct tensor approach first + # since most Keras-Hub vision models expect single tensor inputs + if ( + len(self.expected_inputs) == 1 + and self.expected_inputs[0] == "images" + ): + try: + return self.keras_hub_model( + inputs[0], 
training=training, mask=mask + ) + except Exception: + # Fall back to dictionary approach if that fails + pass + + # For LiteRT export, we need to handle the fact that different + # Keras Hub models expect inputs in different formats. Some + # expect dictionaries, others expect single tensors. + try: + # First, try mapping to the expected input names (dictionary + # format) + input_dict = {} + if len(self.expected_inputs) == 1: + input_dict[self.expected_inputs[0]] = inputs[0] + else: + for i, input_name in enumerate(self.expected_inputs): + input_dict[input_name] = inputs[i] + + return self.keras_hub_model( + input_dict, training=training, mask=mask + ) + except ValueError as e: + error_msg = str(e) + # If that fails, try direct tensor input (positional format) + if ( + "doesn't match the expected structure" in error_msg + and "Expected: keras_tensor" in error_msg + ): + # The model expects a single tensor, not a dictionary + if len(inputs) == 1: + return self.keras_hub_model( + inputs[0], training=training, mask=mask + ) + else: + # Multiple inputs - try as positional arguments + return self.keras_hub_model( + *inputs, training=training, mask=mask + ) + elif "Missing data for input" in error_msg: + # Extract the actual expected input names from the error + if "Expected the following keys:" in error_msg: + # Parse the expected keys from error message + start = error_msg.find( + "Expected the following keys: [" + ) + if start != -1: + start += len("Expected the following keys: [") + end = error_msg.find("]", start) + if end != -1: + keys_str = error_msg[start:end] + actual_input_names = [ + k.strip().strip("'\"") + for k in keys_str.split(",") + ] + + # Map inputs to actual expected names + input_dict = {} + for i, actual_name in enumerate( + actual_input_names + ): + if i < len(inputs): + input_dict[actual_name] = inputs[i] + + return self.keras_hub_model( + input_dict, training=training, mask=mask + ) + + # If we still can't figure it out, re-raise the original + 
# error + raise + + def get_config(self): + """Return the configuration of the wrapped model.""" + return self.keras_hub_model.get_config() + + # Pass the correct parameter based on model type + is_text_model = self.config.MODEL_TYPE in [ + "causal_lm", + "text_classifier", + "seq2seq_lm", + ] + param = self.max_sequence_length if is_text_model else None + + return KerasHubModelWrapper( + self.model, + self.config.EXPECTED_INPUTS, + self.config.get_input_signature(param), + ) + + +# Convenience function for direct export +def export_litert(model, filepath, **kwargs): + """Export a Keras-Hub model to Litert format. + + This is a convenience function that automatically detects the model type + and exports it using the appropriate configuration. + + Args: + model: The Keras-Hub model to export + filepath: Path where to save the exported model (without extension) + **kwargs: Additional arguments passed to the exporter + """ + from keras_hub.src.export.base import ExporterRegistry + + # Get the appropriate configuration for this model + config = ExporterRegistry.get_config_for_model(model) + + # Create and use the Litert exporter + exporter = LitertExporter(config, **kwargs) + exporter.export(filepath) diff --git a/keras_hub/src/export/registry.py b/keras_hub/src/export/registry.py new file mode 100644 index 0000000000..246572bfbd --- /dev/null +++ b/keras_hub/src/export/registry.py @@ -0,0 +1,168 @@ +"""Registry initialization for Keras-Hub export functionality. + +This module initializes the export registry with available configurations and +exporters. 
# --- keras_hub/src/export/registry.py ---
"""Registry initialization for Keras-Hub export functionality.

This module initializes the export registry with available configurations
and exporters, and (as a module-level side effect) monkey-patches
`Task.export` so Keras-Hub models route LiteRT exports through this
registry.
"""

from keras_hub.src.export.base import ExporterRegistry
from keras_hub.src.export.configs import CausalLMExporterConfig
from keras_hub.src.export.configs import ImageClassifierExporterConfig
from keras_hub.src.export.configs import ImageSegmenterExporterConfig
from keras_hub.src.export.configs import ObjectDetectorExporterConfig
from keras_hub.src.export.configs import Seq2SeqLMExporterConfig
from keras_hub.src.export.configs import TextClassifierExporterConfig
from keras_hub.src.export.configs import TextModelExporterConfig


def initialize_export_registry():
    """Register all known exporter configurations and exporters.

    Safe to call more than once: registration simply overwrites the same
    keys with the same classes.
    """
    # Text model configurations.
    ExporterRegistry.register_config("causal_lm", CausalLMExporterConfig)
    ExporterRegistry.register_config(
        "text_classifier", TextClassifierExporterConfig
    )
    ExporterRegistry.register_config("seq2seq_lm", Seq2SeqLMExporterConfig)
    ExporterRegistry.register_config("text_model", TextModelExporterConfig)

    # Vision model configurations.
    ExporterRegistry.register_config(
        "image_classifier", ImageClassifierExporterConfig
    )
    ExporterRegistry.register_config(
        "object_detector", ObjectDetectorExporterConfig
    )
    ExporterRegistry.register_config(
        "image_segmenter", ImageSegmenterExporterConfig
    )

    # Format exporters. LiteRT is optional; skip quietly if unavailable.
    try:
        from keras_hub.src.export.litert import LitertExporter

        ExporterRegistry.register_exporter("litert", LitertExporter)
    except ImportError:
        pass


def export_model(model, filepath, format="litert", **kwargs):
    """Export a Keras-Hub model to the specified format.

    Main export entry point: detects the model type automatically and uses
    the matching exporter configuration.

    Args:
        model: The Keras-Hub model to export.
        filepath: Path where to save the exported model (without extension).
        format: Export format (currently supports "litert").
        **kwargs: Additional arguments passed to the exporter.
    """
    # Registry is initialized at module level, so lookups are ready here.
    config = ExporterRegistry.get_config_for_model(model)
    exporter = ExporterRegistry.get_exporter(format, config, **kwargs)
    exporter.export(filepath)


def extend_export_method_for_keras_hub():
    """Monkey-patch `Task.export` to handle dictionary model inputs.

    Installs an `export` method (and a `_is_keras_hub_model` helper) on the
    Keras-Hub `Task` class. Non-LiteRT formats fall back to the original
    Keras export implementation.
    """
    try:
        import keras

        from keras_hub.src.models.task import Task

        # Keep a reference to whichever export method would otherwise run.
        original_export = getattr(Task, "export", None) or getattr(
            keras.Model, "export", None
        )

        def keras_hub_export(
            self,
            filepath,
            format="litert",
            verbose=False,
            **kwargs,
        ):
            """Extended export method for Keras-Hub models.

            This method extends Keras' export functionality to properly
            handle Keras-Hub models that expect dictionary inputs.

            Args:
                filepath: Path where to save the exported model (without
                    extension).
                format: Export format. Supports "litert", "tf_saved_model",
                    etc.
                verbose: Whether to print verbose output during export.
                **kwargs: Additional arguments passed to the exporter.
            """
            if format == "litert" and self._is_keras_hub_model():
                # Route through the Keras-Hub specific export logic.
                kwargs["verbose"] = verbose
                export_model(self, filepath, format=format, **kwargs)
            elif original_export:
                # Fall back to the original Keras export method.
                original_export(
                    self, filepath, format=format, verbose=verbose, **kwargs
                )
            else:
                raise NotImplementedError(
                    f"Export format '{format}' not supported for this "
                    "model type"
                )

        def _is_keras_hub_model(self):
            """Return True if this model needs Keras-Hub export handling."""
            # BUGFIX: dropped the always-true `hasattr(self, "__class__")`
            # guard from the original.
            class_name = self.__class__.__name__
            module_name = self.__class__.__module__

            # Anything defined inside the keras_hub package qualifies.
            if "keras_hub" in module_name:
                return True

            # Keras-Hub tasks carry both a preprocessor and a backbone.
            if hasattr(self, "preprocessor") and hasattr(self, "backbone"):
                return True

            # Last resort: match common Keras-Hub task class names.
            keras_hub_model_names = (
                "CausalLM",
                "Seq2SeqLM",
                "TextClassifier",
                "ImageClassifier",
                "ObjectDetector",
                "ImageSegmenter",
            )
            return any(name in class_name for name in keras_hub_model_names)

        Task._is_keras_hub_model = _is_keras_hub_model
        Task.export = keras_hub_export

    except ImportError:
        # Task class not available; skip the extension entirely.
        pass
    except Exception as e:
        # Deliberate best-effort: never fail the import because of the
        # monkey-patch, but do tell the user it did not take effect.
        import warnings

        warnings.warn(
            f"Failed to extend export method for Keras-Hub models: {e}"
        )


# Initialize the registry (and install the Task.export patch) on import.
initialize_export_registry()
extend_export_method_for_keras_hub()
# --- keras_hub/src/models/__init__.py ---
"""Import and initialize Keras-Hub export functionality.

This module automatically extends Keras-Hub models with export capabilities
when imported.
"""

import warnings

try:
    from keras_hub.src.export.registry import (
        extend_export_method_for_keras_hub,
    )
    from keras_hub.src.export.registry import initialize_export_registry

    # NOTE(review): the registry module already runs both calls at import
    # time; repeating them here is redundant but idempotent.
    initialize_export_registry()
    extend_export_method_for_keras_hub()
except ImportError as e:
    # Best-effort: the rest of keras_hub.models stays usable without the
    # export extension.
    warnings.warn(
        f"Failed to import Keras-Hub export functionality: {e}",
        ImportWarning,
        stacklevel=2,
    )


# --- keras_hub/src/models/backbone.py (Backbone methods) ---

    def _get_save_spec(self, dynamic_batch=True):
        """Compatibility shim for Keras/TensorFlow saving utilities.

        TensorFlow's SavedModel / TFLite export paths expect a
        `_get_save_spec` method on subclassed models. In some runtime
        combinations this method may not be present on the MRO for our
        `Backbone` subclass; delegate to the superclass first and fall back
        to constructing `tf.TensorSpec` objects from the functional
        `inputs` if needed.

        Args:
            dynamic_batch: whether to set the batch dimension to `None`.

        Returns:
            A TensorSpec, list or dict mirroring the model inputs, or
            `None` when specs cannot be inferred.
        """
        # Prefer the base implementation if available.
        try:
            return super()._get_save_spec(dynamic_batch)
        except AttributeError:
            pass

        # Fall back to building specs from `self.inputs`.
        # NOTE: private TF module; ModuleNotFoundError is an ImportError
        # subclass, so one clause covers both.
        try:
            from tensorflow.python.framework import tensor_spec
        except ImportError:
            return None

        inputs = getattr(self, "inputs", None)
        if inputs is None:
            return None

        def _make_spec(t):
            # `t` is a tf.Tensor-like object.
            shape = list(t.shape)
            if dynamic_batch and shape:
                shape[0] = None  # make the batch dimension dynamic
            try:
                return tensor_spec.TensorSpec(
                    shape=tuple(shape),
                    dtype=t.dtype,
                    name=getattr(t, "name", None),
                )
            except (TypeError, ValueError):
                # BUGFIX: the original caught (ImportError,
                # ModuleNotFoundError) here, which TensorSpec construction
                # never raises; invalid shape/dtype raises TypeError or
                # ValueError instead.
                return None

        # Mirror the structure of `inputs` (dict / list / single tensor).
        if isinstance(inputs, dict):
            return {k: _make_spec(v) for k, v in inputs.items()}
        if isinstance(inputs, (list, tuple)):
            return [_make_spec(t) for t in inputs]
        return _make_spec(inputs)

    def _trackable_children(self, save_type=None, **kwargs):
        """Override to prevent `_DictWrapper` issues during TF export.

        Filters out problematic `_DictWrapper` objects that cause TypeError
        during SavedModel introspection, while preserving all essential
        trackable components.

        Args:
            save_type: forwarded to the superclass when provided; when
                `None`, the superclass default is preserved.
            **kwargs: forwarded to the superclass unchanged.

        Returns:
            A dict of trackable children with `_DictWrapper` objects
            unwrapped into plain lists/dicts of trackables.
        """
        # BUGFIX: only forward `save_type` when the caller supplied one so
        # the superclass keeps its own (non-None) default.
        if save_type is None:
            children = super()._trackable_children(**kwargs)
        else:
            children = super()._trackable_children(save_type, **kwargs)

        # NOTE: private TF module; if it moved, return children unfiltered.
        try:
            from tensorflow.python.trackable.data_structures import (
                _DictWrapper,
            )
        except ImportError:
            return children

        clean_children = {}
        for name, child in children.items():
            if not isinstance(child, _DictWrapper):
                # Keep non-_DictWrapper children as-is.
                clean_children[name] = child
                continue
            try:
                data = getattr(child, "_data", None)
                if isinstance(data, list):
                    # List-like wrapper (e.g. transformer_layers): keep
                    # only the items that are themselves trackable.
                    unwrapped = [
                        item
                        for item in data
                        if hasattr(item, "_trackable_children")
                    ]
                    if unwrapped:
                        clean_children[name] = unwrapped
                elif isinstance(data, dict):
                    # Dict-like wrapper: same filtering, keyed.
                    unwrapped = {
                        k: v
                        for k, v in data.items()
                        if hasattr(v, "_trackable_children")
                    }
                    if unwrapped:
                        clean_children[name] = unwrapped
                # Wrappers we cannot unwrap safely are dropped on purpose.
            except (AttributeError, TypeError):
                # Skip problematic _DictWrapper objects entirely.
                continue

        return clean_children