Model Export to liteRT #21674 (Open)
pctablet505 wants to merge 52 commits into keras-team:master from pctablet505:export
+1,033 −10
Changes shown from 34 of 52 commits.
Commits
All commits by pctablet505:
837506d  Add LiteRT (TFLite) export support to Keras
631850e  Update lite_rt_exporter.py
f5aa72e  Update export_utils.py
2b952d6  Refactor LiteRTExporter to simplify TFLite conversion
8f81dd5  Refactor import structure to avoid circular dependencies
011f1d8  trying kerashub
9a99a32  Enhance LiteRT export for sequence models and large models
d0070c6  Update lite_rt_exporter.py
761793f  Update lite_rt_exporter.py
7bb0506  Prevent tensor overflow for large vocabulary models
c219eb1  Update export_utils.py
e26ff6b  Update lite_rt_exporter.py
4a32e04  Simplify TFLite export and sequence length safety checks
3aca2f6  Merge branch 'keras-team:master' into export
926b0a8  Refactor TFLite export logic and add simple exporter
441a778  Merge branch 'export' of https://github.com/pctablet505/keras into ex…
4a8a9d5  Improve export robustness for large vocab and Keras-Hub models
f4b43b4  Update lite_rt_exporter.py
0fe4bd5  Update lite_rt_exporter.py
8c3faa3  Update lite_rt_exporter.py
88b6a6f  Update lite_rt_exporter.py
da13d04  Update lite_rt_exporter.py
f1f700c  Update lite_rt_exporter.py
5944780  Update lite_rt_exporter.py
4404c39  Update lite_rt_exporter.py
6a119fb  Update lite_rt_exporter.py
4cec7cd  Merge branch 'keras-team:master' into export
3a7fcc4  Merge branch 'keras-team:master' into export
51a1c7f  Remove sequence length bounding from export utils
e1fca24  Delete test_keras_hub_export.py
214558a  Merge branch 'keras-team:master' into export
73f00f1  Rename LiteRT exporter to Litert and update references
ebf11e2  Enhance LiteRT exporter and expand export tests
c6f0c70  Refactor LiteRT exporter to use module_utils.litert
3c1d90a  Simplify export_litert return value and messaging
657a271  Merge branch 'keras-team:master' into export
8ce8bfa  Merge branch 'export' of https://github.com/pctablet505/keras into ex…
cd9d063  Update export_utils.py
fa3d3ed  Refactor input signature inference for export
e775ff2  simplified code
34b662d  Refactor LiteRT exporter and update import paths
33b0550  Merge branch 'keras-team:master' into export
cbe0229  Refactor import statements for export_utils functions
e52de85  Update saved_model.py
87af9ed  Update litert.py
c643772  Add conditional TensorFlow import for LiteRT export
f243a6e  reformat
d8236fa  Update litert_test.py
83577be  Update litert_test.py
c53b264  Update litert_test.py
487184d  Update litert_test.py
374d90b  Update requirements-tensorflow-cuda.txt
Diff: export_utils.py
@@ -7,6 +7,14 @@
 def get_input_signature(model):
+    """Get input signature for model export.
+
+    Args:
+        model: A Keras Model instance.
+
+    Returns:
+        Input signature suitable for model export.
+    """
     if not isinstance(model, models.Model):
         raise TypeError(
             "The model must be a `keras.Model`. "
@@ -17,19 +25,29 @@ def get_input_signature(model):
             "The model provided has not yet been built. It must be built "
             "before export."
         )
     if isinstance(model, models.Functional):
-        input_signature = [
-            tree.map_structure(make_input_spec, model._inputs_struct)
-        ]
+        input_signature = tree.map_structure(
+            make_input_spec, model._inputs_struct
+        )
     elif isinstance(model, models.Sequential):
         input_signature = tree.map_structure(make_input_spec, model.inputs)
     else:
+        # For subclassed models, try multiple approaches
         input_signature = _infer_input_signature_from_model(model)
-        if not input_signature or not model._called:
-            raise ValueError(
-                "The model provided has never called. "
-                "It must be called at least once before export."
-            )
+        if not input_signature:
+            # Fallback: Try to get from model.inputs if available
+            if hasattr(model, "inputs") and model.inputs:
+                input_signature = tree.map_structure(
+                    make_input_spec, model.inputs
+                )
+            elif not model._called:
+                raise ValueError(
+                    "The model provided has never been called and has no "
+                    "detectable input structure. It must be called at least "
+                    "once before export, or you must provide explicit "
+                    "input_signature."
+                )
     return input_signature
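The fallback chain in the new subclassed-model branch can be sketched without a Keras dependency. `resolve_signature` below is a hypothetical stand-in, not a Keras API; it only mirrors the ordering the diff implements: an inferred signature wins, then `model.inputs`, and an error is raised only when the model was never called.

```python
def resolve_signature(inferred, inputs, called):
    """Mirror the diff's fallback chain for subclassed models."""
    if inferred:
        # Inference from the model succeeded; use it directly.
        return inferred
    if inputs:
        # Fallback: build specs from model.inputs if available.
        return [f"spec({shape})" for shape in inputs]
    if not called:
        raise ValueError(
            "The model has never been called and has no detectable "
            "input structure."
        )
    # Called, but nothing detectable: return the empty result as-is.
    return None

# Usage: inference wins, then inputs, then the error path.
print(resolve_signature(["spec((None, 4))"], None, True))  # inferred wins
print(resolve_signature(None, [(None, 8)], True))          # inputs fallback
```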
@@ -45,22 +63,56 @@ def _make_input_spec(structure):
             return {k: _make_input_spec(v) for k, v in structure.items()}
         elif isinstance(structure, tuple):
             if all(isinstance(d, (int, type(None))) for d in structure):
+                # Keep batch dimension unbounded, keep other dimensions as
+                # they are
+                bounded_shape = []
+
+                for i, dim in enumerate(structure):
+                    if dim is None and i == 0:
+                        # Always keep batch dimension as None
+                        bounded_shape.append(None)
+                    else:
+                        # Keep other dimensions as they are (None or
+                        # specific size)
+                        bounded_shape.append(dim)
+
                 return layers.InputSpec(
-                    shape=(None,) + structure[1:], dtype=model.input_dtype
+                    shape=tuple(bounded_shape), dtype=model.input_dtype
                 )
             return tuple(_make_input_spec(v) for v in structure)
         elif isinstance(structure, list):
             if all(isinstance(d, (int, type(None))) for d in structure):
+                # Keep batch dimension unbounded, keep other dimensions as
+                # they are
+                bounded_shape = []
+
+                for i, dim in enumerate(structure):
+                    if dim is None and i == 0:
+                        # Always keep batch dimension as None
+                        bounded_shape.append(None)
+                    else:
+                        # Keep other dimensions as they are
+                        bounded_shape.append(dim)
+
                 return layers.InputSpec(
-                    shape=[None] + structure[1:], dtype=model.input_dtype
+                    shape=bounded_shape, dtype=model.input_dtype
                 )
             return [_make_input_spec(v) for v in structure]
         else:
             raise ValueError(
                 f"Unsupported type {type(structure)} for {structure}"
             )

-    return [_make_input_spec(value) for value in shapes_dict.values()]
+    # Try to reconstruct the input structure from build shapes
+    if len(shapes_dict) == 1:
+        # Single input case
+        return _make_input_spec(list(shapes_dict.values())[0])
+    else:
+        # Multiple inputs - try to determine if it's a dict or list structure
+        # Return as dictionary by default to preserve input names
+        return {
+            key: _make_input_spec(shape) for key, shape in shapes_dict.items()
+        }


 def make_input_spec(x):
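Replayed standalone, the `bounded_shape` loop shows its net effect: every dimension is appended unchanged (the `i == 0` branch appends `None` only when the value already is `None`), so the only behavioral difference from the old `shape=(None,) + structure[1:]` is that a concrete batch size is now preserved rather than overwritten with `None`. `bound_shape` is a hypothetical name used for this sketch.

```python
def bound_shape(structure):
    """Replay the diff's bounded_shape loop on a plain shape tuple."""
    bounded_shape = []
    for i, dim in enumerate(structure):
        if dim is None and i == 0:
            # Batch dimension stays unbounded (already None).
            bounded_shape.append(None)
        else:
            # All other dimensions, and a concrete batch size, pass through.
            bounded_shape.append(dim)
    return tuple(bounded_shape)

print(bound_shape((None, 128, 3)))  # unbounded batch is kept as None
print(bound_shape((32, 128, 3)))    # concrete batch size is preserved
```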
@@ -105,3 +157,38 @@ def convert_spec_to_tensor(spec, replace_none_number=None):
             s if s is not None else replace_none_number for s in shape
         )
     return ops.ones(shape, spec.dtype)
+
+
+# Registry for export formats
+EXPORT_FORMATS = {
+    "tf_saved_model": "keras.src.export.saved_model:export_saved_model",
+    "litert": "keras.src.export.litert_exporter:export_litert",
+    # Add other formats as needed
+}
+
+
+def _get_exporter(format_name):
+    """Lazy import exporter to avoid circular imports."""
+    if format_name not in EXPORT_FORMATS:
+        raise ValueError(f"Unknown export format: {format_name}")
+
+    exporter = EXPORT_FORMATS[format_name]
+    if isinstance(exporter, str):
+        # Lazy import for string references
+        module_path, attr_name = exporter.split(":")
+        module = __import__(module_path, fromlist=[attr_name])
+        return getattr(module, attr_name)
+    else:
+        # Direct reference
+        return exporter
+
+
+def export_model(model, filepath, format="tf_saved_model", **kwargs):
+    """Export a model to the specified format."""
+    exporter = _get_exporter(format)
+
+    if isinstance(exporter, type):
+        exporter_instance = exporter(model, **kwargs)
+        return exporter_instance.export(filepath)
+
+    return exporter(model, filepath, **kwargs)
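The registry's lazy-import mechanics can be exercised with stdlib modules standing in for the Keras exporters. This sketch uses `importlib.import_module` instead of the diff's `__import__(module_path, fromlist=[attr_name])`, but it resolves the same `"module:attr"` string format; `FORMATS` and `get_exporter` are illustrative names, not the Keras ones.

```python
import importlib

# Hypothetical registry mirroring EXPORT_FORMATS, with stdlib targets
# standing in for the Keras exporter functions.
FORMATS = {
    "json": "json:dumps",
    "pickle": "pickle:dumps",
}

def get_exporter(name):
    """Resolve a "module:attr" string lazily, as _get_exporter does."""
    if name not in FORMATS:
        raise ValueError(f"Unknown export format: {name}")
    target = FORMATS[name]
    if isinstance(target, str):
        # The module is only imported when this format is requested,
        # which is what breaks the circular-import cycle in the diff.
        module_path, attr_name = target.split(":")
        module = importlib.import_module(module_path)
        return getattr(module, attr_name)
    return target  # already a direct callable reference

print(get_exporter("json")({"a": 1}))  # resolves json.dumps lazily
```

The string-based indirection means `export_utils` never imports `saved_model` or `litert_exporter` at module load time, only on first use of that format.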
Review comment:
ditto here, what's the case you are trying to address?