-
Notifications
You must be signed in to change notification settings - Fork 301
Model Export to liteRT #2405
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
pctablet505
wants to merge
23
commits into
keras-team:master
Choose a base branch
from
pctablet505:export
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Model Export to liteRT #2405
Changes from 11 commits
Commits
Show all changes
23 commits
Select commit
Hold shift + click to select a range
087b9b2
Update backbone.py
pctablet505 de830b1
Update backbone.py
pctablet505 62d2484
Update task.py
pctablet505 3b71125
Revert "Update task.py"
pctablet505 3d453ff
Revert "Update backbone.py"
pctablet505 92b1254
export
pctablet505 e46241d
refactoring
pctablet505 6e970e2
refactor
pctablet505 15ad9f3
Update registry.py
pctablet505 02ca0d9
Refactor export logic and improve error handling
pctablet505 901c233
Merge branch 'keras-team:master' into export
pctablet505 442fdd3
reformat
pctablet505 5446e2a
Add export submodule to keras_hub API
pctablet505 5c31d88
reformat
pctablet505 3290d42
now supporting export for objectDetectors
pctablet505 8b1024f
Add and refine image model exporter configs
pctablet505 8df5a75
Refactor: move keras import to module level
pctablet505 759d223
Remove debug_object_detection.py script
pctablet505 0737c93
Rename LiteRT to Litert and update exporter configs
pctablet505 c733e18
Refactor InputSpec formatting and fix import path
pctablet505 5ab911f
Refactor exporter configs and model building logic
pctablet505 c1e26dd
Refactor export initialization and improve warnings
pctablet505 6fa8379
Improve dtype handling and verbose output in exporters
pctablet505 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from keras_hub.src.export.base import ExporterRegistry | ||
from keras_hub.src.export.base import KerasHubExporter | ||
from keras_hub.src.export.base import KerasHubExporterConfig | ||
from keras_hub.src.export.configs import CausalLMExporterConfig | ||
from keras_hub.src.export.configs import Seq2SeqLMExporterConfig | ||
from keras_hub.src.export.configs import TextClassifierExporterConfig | ||
from keras_hub.src.export.configs import TextModelExporterConfig | ||
from keras_hub.src.export.lite_rt import export_lite_rt | ||
from keras_hub.src.export.lite_rt import LiteRTExporter |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,252 @@ | ||
"""Base classes for Keras-Hub model exporters. | ||
|
||
This module provides the foundation for exporting Keras-Hub models to various formats. | ||
It follows the Optimum pattern of having different exporters for different model types and formats. | ||
""" | ||
|
||
from abc import ABC, abstractmethod | ||
from typing import Dict, Any, Optional, List, Type | ||
|
||
try: | ||
import keras | ||
from keras.src.export.export_utils import get_input_signature | ||
KERAS_AVAILABLE = True | ||
except ImportError: | ||
KERAS_AVAILABLE = False | ||
keras = None | ||
|
||
|
||
class KerasHubExporterConfig(ABC): | ||
"""Base configuration class for Keras-Hub model exporters. | ||
|
||
This class defines the interface for exporter configurations that specify | ||
how different types of Keras-Hub models should be exported. | ||
""" | ||
|
||
# Model type this exporter handles (e.g., "causal_lm", "text_classifier") | ||
MODEL_TYPE: str = None | ||
|
||
# Expected input structure for this model type | ||
EXPECTED_INPUTS: List[str] = [] | ||
|
||
# Default sequence length if not specified | ||
DEFAULT_SEQUENCE_LENGTH: int = 128 | ||
|
||
def __init__(self, model, **kwargs): | ||
"""Initialize the exporter configuration. | ||
|
||
Args: | ||
model: The Keras-Hub model to export | ||
**kwargs: Additional configuration parameters | ||
""" | ||
self.model = model | ||
self.config_kwargs = kwargs | ||
self._validate_model() | ||
|
||
def _validate_model(self): | ||
"""Validate that the model is compatible with this exporter.""" | ||
if not self._is_model_compatible(): | ||
raise ValueError( | ||
f"Model {self.model.__class__.__name__} is not compatible " | ||
f"with {self.__class__.__name__}" | ||
) | ||
|
||
@abstractmethod | ||
def _is_model_compatible(self) -> bool: | ||
"""Check if the model is compatible with this exporter.""" | ||
pass | ||
|
||
@abstractmethod | ||
def get_input_signature(self, sequence_length: Optional[int] = None) -> Dict[str, Any]: | ||
"""Get the input signature for this model type. | ||
|
||
Args: | ||
sequence_length: Optional sequence length for input tensors | ||
|
||
Returns: | ||
Dictionary mapping input names to their signatures | ||
""" | ||
pass | ||
|
||
def get_dummy_inputs(self, sequence_length: Optional[int] = None) -> Dict[str, Any]: | ||
"""Generate dummy inputs for model building and testing. | ||
|
||
Args: | ||
sequence_length: Optional sequence length for dummy inputs | ||
|
||
Returns: | ||
Dictionary of dummy inputs | ||
""" | ||
if sequence_length is None: | ||
sequence_length = self.DEFAULT_SEQUENCE_LENGTH | ||
|
||
dummy_inputs = {} | ||
|
||
# Common inputs for most Keras-Hub models | ||
if "token_ids" in self.EXPECTED_INPUTS: | ||
dummy_inputs["token_ids"] = keras.ops.ones((1, sequence_length), dtype='int32') | ||
if "padding_mask" in self.EXPECTED_INPUTS: | ||
dummy_inputs["padding_mask"] = keras.ops.ones((1, sequence_length), dtype='bool') | ||
|
||
# Encoder-decoder specific inputs | ||
if "encoder_token_ids" in self.EXPECTED_INPUTS: | ||
dummy_inputs["encoder_token_ids"] = keras.ops.ones((1, sequence_length), dtype='int32') | ||
if "encoder_padding_mask" in self.EXPECTED_INPUTS: | ||
dummy_inputs["encoder_padding_mask"] = keras.ops.ones((1, sequence_length), dtype='bool') | ||
if "decoder_token_ids" in self.EXPECTED_INPUTS: | ||
dummy_inputs["decoder_token_ids"] = keras.ops.ones((1, sequence_length), dtype='int32') | ||
if "decoder_padding_mask" in self.EXPECTED_INPUTS: | ||
dummy_inputs["decoder_padding_mask"] = keras.ops.ones((1, sequence_length), dtype='bool') | ||
|
||
return dummy_inputs | ||
|
||
|
||
class KerasHubExporter(ABC): | ||
"""Base class for Keras-Hub model exporters. | ||
|
||
This class provides the common interface for exporting Keras-Hub models | ||
to different formats (LiteRT, ONNX, etc.). | ||
""" | ||
|
||
def __init__(self, config: KerasHubExporterConfig, **kwargs): | ||
"""Initialize the exporter. | ||
|
||
Args: | ||
config: Exporter configuration specifying model type and parameters | ||
**kwargs: Additional exporter-specific parameters | ||
""" | ||
self.config = config | ||
self.model = config.model | ||
self.export_kwargs = kwargs | ||
|
||
@abstractmethod | ||
def export(self, filepath: str) -> None: | ||
"""Export the model to the specified filepath. | ||
|
||
Args: | ||
filepath: Path where to save the exported model | ||
""" | ||
pass | ||
|
||
def _ensure_model_built(self, sequence_length: Optional[int] = None) -> None: | ||
"""Ensure the model is properly built with correct input structure. | ||
|
||
Args: | ||
sequence_length: Optional sequence length for dummy inputs | ||
""" | ||
if not self.model.built: | ||
dummy_inputs = self.config.get_dummy_inputs(sequence_length) | ||
|
||
try: | ||
# Build the model with the correct input structure | ||
_ = self.model(dummy_inputs, training=False) | ||
except Exception as e: | ||
# Try alternative approach using build() method | ||
try: | ||
input_shapes = {key: tensor.shape for key, tensor in dummy_inputs.items()} | ||
self.model.build(input_shape=input_shapes) | ||
except Exception: | ||
raise ValueError( | ||
f"Failed to build model: {e}. Please ensure the model is properly constructed." | ||
) | ||
|
||
|
||
class ExporterRegistry: | ||
"""Registry for mapping model types to their appropriate exporters.""" | ||
|
||
_configs = {} | ||
_exporters = {} | ||
|
||
@classmethod | ||
def register_config(cls, model_type: str, config_class: Type[KerasHubExporterConfig]) -> None: | ||
"""Register a configuration class for a model type. | ||
|
||
Args: | ||
model_type: The model type (e.g., "causal_lm") | ||
config_class: The configuration class | ||
""" | ||
cls._configs[model_type] = config_class | ||
|
||
@classmethod | ||
def register_exporter(cls, format_name: str, exporter_class: Type[KerasHubExporter]) -> None: | ||
"""Register an exporter class for a format. | ||
|
||
Args: | ||
format_name: The export format (e.g., "lite_rt") | ||
exporter_class: The exporter class | ||
""" | ||
cls._exporters[format_name] = exporter_class | ||
|
||
@classmethod | ||
def get_config_for_model(cls, model) -> KerasHubExporterConfig: | ||
"""Get the appropriate configuration for a model. | ||
|
||
Args: | ||
model: The Keras-Hub model | ||
|
||
Returns: | ||
An appropriate exporter configuration instance | ||
|
||
Raises: | ||
ValueError: If no configuration is found for the model type | ||
""" | ||
model_type = cls._detect_model_type(model) | ||
|
||
if model_type not in cls._configs: | ||
raise ValueError(f"No configuration found for model type: {model_type}") | ||
|
||
config_class = cls._configs[model_type] | ||
return config_class(model) | ||
|
||
@classmethod | ||
def get_exporter(cls, format_name: str, config: KerasHubExporterConfig, **kwargs) -> KerasHubExporter: | ||
"""Get an exporter for the specified format. | ||
|
||
Args: | ||
format_name: The export format | ||
config: The exporter configuration | ||
**kwargs: Additional parameters for the exporter | ||
|
||
Returns: | ||
An appropriate exporter instance | ||
|
||
Raises: | ||
ValueError: If no exporter is found for the format | ||
""" | ||
if format_name not in cls._exporters: | ||
raise ValueError(f"No exporter found for format: {format_name}") | ||
|
||
exporter_class = cls._exporters[format_name] | ||
return exporter_class(config, **kwargs) | ||
|
||
@classmethod | ||
def _detect_model_type(cls, model) -> str: | ||
"""Detect the model type from the model instance. | ||
|
||
Args: | ||
model: The Keras-Hub model | ||
|
||
Returns: | ||
The detected model type | ||
""" | ||
# Import here to avoid circular imports | ||
try: | ||
from keras_hub.src.models.causal_lm import CausalLM | ||
from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM | ||
except ImportError: | ||
CausalLM = None | ||
Seq2SeqLM = None | ||
|
||
model_class_name = model.__class__.__name__ | ||
|
||
if CausalLM and isinstance(model, CausalLM): | ||
return "causal_lm" | ||
elif 'TextClassifier' in model_class_name: | ||
return "text_classifier" | ||
elif Seq2SeqLM and isinstance(model, Seq2SeqLM): | ||
return "seq2seq_lm" | ||
elif 'ImageClassifier' in model_class_name: | ||
return "image_classifier" | ||
else: | ||
# Default to text model for generic Keras-Hub models | ||
return "text_model" |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.