diff --git a/keras/src/export/__init__.py b/keras/src/export/__init__.py index 7adfd18513f..8782bca44ce 100644 --- a/keras/src/export/__init__.py +++ b/keras/src/export/__init__.py @@ -3,3 +3,12 @@ from keras.src.export.saved_model import ExportArchive from keras.src.export.saved_model import export_saved_model from keras.src.export.tfsm_layer import TFSMLayer + +# LiteRT export requires TensorFlow, so we import conditionally +try: + from keras.src.export.litert import LitertExporter + from keras.src.export.litert import export_litert +except ImportError: + # TensorFlow not available, LiteRT export will not be available + LitertExporter = None + export_litert = None diff --git a/keras/src/export/export_utils.py b/keras/src/export/export_utils.py index 4b76f68fe4a..641e03e2335 100644 --- a/keras/src/export/export_utils.py +++ b/keras/src/export/export_utils.py @@ -7,6 +7,14 @@ def get_input_signature(model): + """Get input signature for model export. + + Args: + model: A Keras Model instance. + + Returns: + Input signature suitable for model export (always a tuple or list). + """ if not isinstance(model, models.Model): raise TypeError( "The model must be a `keras.Model`. " @@ -17,18 +25,25 @@ def get_input_signature(model): "The model provided has not yet been built. It must be built " "before export." ) + if isinstance(model, models.Functional): + # Functional models expect a single positional argument `inputs` + # containing the full nested input structure. We keep the + # original behavior of returning a single-element list that + # wraps the mapped structure so that downstream exporters + # build a tf.function with one positional argument. input_signature = [ tree.map_structure(make_input_spec, model._inputs_struct) ] elif isinstance(model, models.Sequential): input_signature = tree.map_structure(make_input_spec, model.inputs) else: + # Subclassed models: rely on recorded shapes from the first call. input_signature = _infer_input_signature_from_model(model) if not input_signature or not model._called: raise ValueError( - "The model provided has never called. " - "It must be called at least once before export." + "The model provided has never been called. It must be called " + "at least once before export." ) return input_signature @@ -41,25 +56,36 @@ def _infer_input_signature_from_model(model): def _make_input_spec(structure): # We need to turn wrapper structures like TrackingDict or _DictWrapper # into plain Python structures because they don't work with jax2tf/JAX.
+ if structure is None: + return None if isinstance(structure, dict): return {k: _make_input_spec(v) for k, v in structure.items()} elif isinstance(structure, tuple): if all(isinstance(d, (int, type(None))) for d in structure): - return layers.InputSpec( - shape=(None,) + structure[1:], dtype=model.input_dtype + # For export, force batch dimension to None for flexible + # batching + shape = ( + (None,) + structure[1:] if len(structure) > 0 else structure ) + return layers.InputSpec(shape=shape, dtype=model.input_dtype) return tuple(_make_input_spec(v) for v in structure) elif isinstance(structure, list): if all(isinstance(d, (int, type(None))) for d in structure): - return layers.InputSpec( - shape=[None] + structure[1:], dtype=model.input_dtype + # For export, force batch dimension to None for flexible + # batching + shape = ( + (None,) + tuple(structure[1:]) + if len(structure) > 0 + else tuple(structure) ) + return layers.InputSpec(shape=shape, dtype=model.input_dtype) return [_make_input_spec(v) for v in structure] else: raise ValueError( f"Unsupported type {type(structure)} for {structure}" ) + # Always return a flat list preserving the order of shapes_dict values return [_make_input_spec(value) for value in shapes_dict.values()] diff --git a/keras/src/export/litert.py b/keras/src/export/litert.py new file mode 100644 index 00000000000..398be83b6ee --- /dev/null +++ b/keras/src/export/litert.py @@ -0,0 +1,474 @@ +import os + +import tensorflow as tf + +from keras.src import tree +from keras.src.utils import io_utils +from keras.src.utils.module_utils import litert + + +def export_litert( + model, + filepath, + verbose=None, + input_signature=None, + aot_compile_targets=None, + **kwargs, +): + """Export the model as a Litert artifact for inference. + + Args: + model: The Keras model to export. + filepath: The path to save the exported artifact. + verbose: `bool`. Whether to print a message during export. Defaults to + `None`, which uses the default value set by different backends and + formats. + input_signature: Optional input signature specification. If + ``None``, it will be inferred. + aot_compile_targets: Optional list of Litert targets for AOT + compilation. + **kwargs: Additional keyword arguments passed to the exporter. + """ + + if verbose is None: + verbose = True # Defaults to `True` for all backends. + + exporter = LitertExporter( + model=model, + input_signature=input_signature, + verbose=verbose, + aot_compile_targets=aot_compile_targets, + **kwargs, + ) + exporter.export(filepath) + if verbose: + io_utils.print_msg(f"Saved artifact at '{filepath}'.") + + +class LitertExporter: + """ + Exporter for the Litert (TFLite) format that creates a single, + callable signature for `model.call`. + """ + + def __init__( + self, + model, + input_signature=None, + verbose=False, + aot_compile_targets=None, + **kwargs, + ): + """Initialize the Litert exporter. + + Args: + model: The Keras model to export + input_signature: Input signature specification + verbose: Whether to print progress messages during export. + aot_compile_targets: List of Litert targets for AOT compilation + **kwargs: Additional export parameters + """ + self.model = model + self.input_signature = input_signature + self.verbose = verbose + self.aot_compile_targets = aot_compile_targets + self.kwargs = kwargs + + def export(self, filepath): + """Exports the Keras model to a TFLite file and optionally performs AOT + compilation. 
+ + Args: + filepath: Output path for the exported model + + Returns: + Path to exported model or compiled models if AOT compilation is + performed + """ + if self.verbose: + print("Starting Litert export...") + + # 1. Ensure the model is built by calling it if necessary + self._ensure_model_built() + + # 2. Resolve / infer input signature + if self.input_signature is None: + if self.verbose: + print("Inferring input signature from model.") + from keras.src.export.export_utils import get_input_signature + + self.input_signature = get_input_signature(self.model) + + # 3. Convert the model to TFLite. + tflite_model = self._convert_to_tflite(self.input_signature) + + if self.verbose: + final_size_mb = len(tflite_model) / (1024 * 1024) + print( + f"TFLite model converted successfully. Size: " + f"{final_size_mb:.2f} MB" + ) + + # 4. Save the initial TFLite model to the specified file path. + assert filepath.endswith(".tflite"), ( + "The LiteRT export requires the filepath to end with '.tflite'. " + f"Got: {filepath}" + ) + + with open(filepath, "wb") as f: + f.write(tflite_model) + + if self.verbose: + print(f"TFLite model saved to {filepath}") + + # 5. Perform AOT compilation if targets are specified and LiteRT is + # available + compiled_models = None + if self.aot_compile_targets and litert.available: + if self.verbose: + print("Performing AOT compilation for Litert targets...") + compiled_models = self._aot_compile(filepath) + elif self.aot_compile_targets and not litert.available: + if self.verbose: + print( + "Warning: AOT compilation requested but LiteRT is not " + "available. Skipping." + ) + + if self.verbose: + print(f"Litert export completed. Base model: {filepath}") + if compiled_models: + print( + f"AOT compiled models: {len(compiled_models.models)} " + "variants" + ) + + return compiled_models if compiled_models else filepath + + def _ensure_model_built(self): + """ + Ensures the model is built before conversion. + + For models that are not yet built, this attempts to build them + using the input signature or model.inputs. + """ + if self.model.built: + return + + if self.verbose: + print("Building model before conversion...") + + try: + # Try to build using input_signature if available + if self.input_signature: + input_shapes = tree.map_structure( + lambda spec: spec.shape, self.input_signature + ) + self.model.build(input_shapes) + # Fall back to model.inputs for Functional/Sequential models + elif hasattr(self.model, "inputs") and self.model.inputs: + input_shapes = [inp.shape for inp in self.model.inputs] + if len(input_shapes) == 1: + self.model.build(input_shapes[0]) + else: + self.model.build(input_shapes) + else: + raise ValueError( + "Cannot build model: no input_signature provided and " + "model has no inputs attribute. Please provide an " + "input_signature or ensure the model is already built." + ) + + if self.verbose: + print("Model built successfully.") + + except Exception as e: + if self.verbose: + print(f"Error building model: {e}") + raise ValueError( + f"Failed to build model: {e}. Please ensure the model is " + "properly defined or provide an input_signature." + ) + + def _convert_to_tflite(self, input_signature): + """Converts the Keras model to a TFLite model.""" + is_sequential = isinstance(self.model, tf.keras.Sequential) + + # Try direct conversion first for all models + try: + if self.verbose: + model_type = "Sequential" if is_sequential else "Functional" + print( + f"{model_type} model detected. Trying direct conversion..." 
+ ) + + converter = tf.lite.TFLiteConverter.from_keras_model(self.model) + converter.target_spec.supported_ops = [ + tf.lite.OpsSet.TFLITE_BUILTINS, + tf.lite.OpsSet.SELECT_TF_OPS, + ] + converter.experimental_enable_resource_variables = False + tflite_model = converter.convert() + + if self.verbose: + print("Direct conversion successful.") + return tflite_model + + except Exception as direct_error: + if self.verbose: + model_type = "Sequential" if is_sequential else "Functional" + print( + f"Direct conversion failed for {model_type} model: " + f"{direct_error}" + ) + print("Falling back to wrapper-based conversion...") + + return self._convert_with_wrapper(input_signature) + + def _convert_with_wrapper(self, input_signature): + """Converts the model to TFLite using the tf.Module wrapper.""" + # 1. Wrap the Keras model in our clean tf.Module. + wrapper = _KerasModelWrapper(self.model) + + # 2. Get a concrete function from the wrapper. + if not isinstance(input_signature, (list, tuple)): + input_signature = [input_signature] + + from keras.src.export.export_utils import make_tf_tensor_spec + + tensor_specs = [make_tf_tensor_spec(spec) for spec in input_signature] + + # Pass tensor specs as positional arguments to get the concrete + # function. + concrete_func = wrapper.__call__.get_concrete_function(*tensor_specs) + + # 3. Convert from the concrete function. + if self.verbose: + print("Converting concrete function to TFLite format...") + + # Try multiple conversion strategies for better inference compatibility + conversion_strategies = [ + { + "experimental_enable_resource_variables": False, + "name": "without resource variables", + }, + { + "experimental_enable_resource_variables": True, + "name": "with resource variables", + }, + ] + + for strategy in conversion_strategies: + try: + converter = tf.lite.TFLiteConverter.from_concrete_functions( + [concrete_func], trackable_obj=wrapper + ) + converter.target_spec.supported_ops = [ + tf.lite.OpsSet.TFLITE_BUILTINS, + tf.lite.OpsSet.SELECT_TF_OPS, + ] + converter.experimental_enable_resource_variables = strategy[ + "experimental_enable_resource_variables" + ] + + if self.verbose: + print(f"Trying conversion {strategy['name']}...") + + tflite_model = converter.convert() + + if self.verbose: + print(f"Conversion successful {strategy['name']}!") + + return tflite_model + + except Exception as e: + if self.verbose: + print(f"Conversion failed {strategy['name']}: {e}") + continue + + # If all strategies fail, raise the last error + raise RuntimeError( + "All conversion strategies failed for wrapper-based conversion" + ) + + def _aot_compile(self, tflite_filepath): + """Performs AOT compilation using LiteRT.""" + if not litert.available: + raise RuntimeError("LiteRT is not available for AOT compilation") + + try: + # Create a LiteRT model from the TFLite file + litert_model = litert.python.aot.core.types.Model.create_from_path( + tflite_filepath + ) + + # Determine output directory + base_dir = os.path.dirname(tflite_filepath) + model_name = os.path.splitext(os.path.basename(tflite_filepath))[0] + output_dir = os.path.join(base_dir, f"{model_name}_compiled") + + if self.verbose: + print(f"AOT compiling for targets: {self.aot_compile_targets}") + print(f"Output directory: {output_dir}") + + # Perform AOT compilation + result = litert.python.aot.aot_compile( + input_model=litert_model, + output_dir=output_dir, + target=self.aot_compile_targets, + keep_going=True, # Continue even if some targets fail + ) + + if self.verbose: + print( + f"AOT 
compilation completed: {len(result.models)} " + f"successful, {len(result.failed_backends)} failed" + ) + if result.failed_backends: + for backend, error in result.failed_backends: + print(f" Failed: {backend.id()} - {error}") + + # Print compilation report if available + try: + report = result.compilation_report() + if report: + print("Compilation Report:") + print(report) + except Exception: + pass + + return result + + except Exception as e: + if self.verbose: + print(f"AOT compilation failed: {e}") + import traceback + + traceback.print_exc() + raise RuntimeError(f"AOT compilation failed: {e}") + + def _get_available_litert_targets(self): + """Get available LiteRT targets for AOT compilation.""" + if not litert.available: + return [] + + try: + # Get all registered targets + targets = ( + litert.python.aot.vendors.import_vendor.AllRegisteredTarget() + ) + return targets if isinstance(targets, list) else [targets] + except Exception as e: + if self.verbose: + print(f"Failed to get available targets: {e}") + return [] + + @classmethod + def export_with_aot( + cls, model, filepath, targets=None, verbose=True, **kwargs + ): + """ + Convenience method to export a Keras model with AOT compilation. + + Args: + model: Keras model to export + filepath: Output file path + targets: List of LiteRT targets for AOT compilation (e.g., + ['qualcomm', 'mediatek']) + verbose: Whether to print verbose output + **kwargs: Additional arguments for the exporter + + Returns: + CompilationResult if AOT compilation is performed, otherwise the + filepath + """ + exporter = cls( + model=model, verbose=verbose, aot_compile_targets=targets, **kwargs + ) + return exporter.export(filepath) + + @classmethod + def get_available_targets(cls): + """Get list of available LiteRT AOT compilation targets.""" + if not litert.available: + return [] + + dummy_exporter = cls(model=None) + return dummy_exporter._get_available_litert_targets() + + +class _KerasModelWrapper(tf.Module): + """ + A tf.Module wrapper for a Keras model. + + This wrapper is designed to be a clean, serializable interface for TFLite + conversion. It holds the Keras model and exposes a single `__call__` + method that is decorated with `tf.function`. Crucially, it also ensures + all variables from the Keras model are tracked by the SavedModel format, + which is key to including them in the final TFLite model. 
+ """ + + def __init__(self, model): + super().__init__() + # Store the model reference in a way that TensorFlow won't try to + # track it + # This prevents the _DictWrapper error during SavedModel serialization + object.__setattr__(self, "_model", model) + + # Track all variables from the Keras model using proper tf.Module + # methods + # This ensures proper variable handling for stateful layers like + # BatchNorm + with self.name_scope: + for i, var in enumerate(model.variables): + # Use a different attribute name to avoid conflicts with + # tf.Module's variables property + setattr(self, f"model_var_{i}", var) + + @tf.function + def __call__(self, *args, **kwargs): + """The single entry point for the exported model.""" + # Handle both single and multi-input cases + if args and not kwargs: + # Called with positional arguments + if len(args) == 1: + return self._model(args[0]) + else: + return self._model(list(args)) + elif kwargs and not args: + # Called with keyword arguments + if len(kwargs) == 1 and "inputs" in kwargs: + # Single input case + return self._model(kwargs["inputs"]) + else: + # Multi-input case - convert to list/dict format expected by + # model + if ( + hasattr(self._model, "inputs") + and len(self._model.inputs) > 1 + ): + # Multi-input functional model + input_list = [] + missing_inputs = [] + for input_layer in self._model.inputs: + input_name = input_layer.name + if input_name in kwargs: + input_list.append(kwargs[input_name]) + else: + missing_inputs.append(input_name) + + if missing_inputs: + raise ValueError( + f"Missing required inputs for multi-input model: " + f"{missing_inputs}. Available kwargs: " + f"{list(kwargs.keys())}. Please provide all inputs " + f"by name." + ) + + return self._model(input_list) + else: + # Single input model called with named arguments + return self._model(list(kwargs.values())[0]) + else: + # Fallback to original call + return self._model(*args, **kwargs) diff --git a/keras/src/export/litert_test.py b/keras/src/export/litert_test.py new file mode 100644 index 00000000000..926a91e1831 --- /dev/null +++ b/keras/src/export/litert_test.py @@ -0,0 +1,467 @@ +import os + +import numpy as np +import pytest +from absl.testing import parameterized + +from keras.src import backend +from keras.src import layers +from keras.src import models +from keras.src import ops +from keras.src import testing +from keras.src import tree +from keras.src.saving import saving_lib +from keras.src.testing.test_utils import named_product +from keras.src.utils.module_utils import litert +from keras.src.utils.module_utils import tensorflow + +# Set up LiteRT interpreter with fallback logic: +# 1. Try AI Edge LiteRT interpreter (preferred) +# 2. 
Fall back to TensorFlow Lite interpreter if AI Edge LiteRT unavailable +AI_EDGE_LITERT_AVAILABLE = False +LiteRtInterpreter = None + +if litert.available: + try: + from ai_edge_litert.interpreter import Interpreter as LiteRtInterpreter + + AI_EDGE_LITERT_AVAILABLE = True + except ImportError: + # Fallback to TensorFlow Lite interpreter if AI Edge LiteRT unavailable + if tensorflow.available: + LiteRtInterpreter = tensorflow.lite.Interpreter +else: + # Fallback to TensorFlow Lite interpreter if AI Edge LiteRT unavailable + if tensorflow.available: + LiteRtInterpreter = tensorflow.lite.Interpreter + +# Model types to test (LSTM only if AI Edge LiteRT is available) +model_types = ["sequential", "functional"] +if AI_EDGE_LITERT_AVAILABLE: + model_types.append("lstm") + + +class CustomModel(models.Model): + def __init__(self, layer_list): + super().__init__() + self.layer_list = layer_list + + def call(self, input): + output = input + for layer in self.layer_list: + output = layer(output) + return output + + +def get_model(type="sequential", input_shape=(10,), layer_list=None): + layer_list = layer_list or [ + layers.Dense(10, activation="relu"), + layers.BatchNormalization(), + layers.Dense(1, activation="sigmoid"), + ] + if type == "sequential": + model = models.Sequential(layer_list) + model.build(input_shape=(None,) + input_shape) + return model + if type == "functional": + input = output = tree.map_shape_structure(layers.Input, input_shape) + for layer in layer_list: + output = layer(output) + return models.Model(inputs=input, outputs=output) + if type == "subclass": + model = CustomModel(layer_list) + model.build(input_shape=(None,) + input_shape) + # Trace the model with dummy data to ensure it's properly built for + # export + dummy_input = np.zeros((1,) + input_shape, dtype=np.float32) + _ = model(dummy_input) # This traces the model + return model + if type == "lstm": + inputs = layers.Input((4, 10)) + x = layers.Bidirectional( + layers.LSTM( + 10, + kernel_initializer="he_normal", + return_sequences=True, + kernel_regularizer=None, + ), + merge_mode="sum", + )(inputs) + outputs = layers.Bidirectional( + layers.LSTM( + 10, + kernel_initializer="he_normal", + return_sequences=True, + kernel_regularizer=None, + ), + merge_mode="concat", + )(x) + return models.Model(inputs=inputs, outputs=outputs) + if type == "multi_input": + input1 = layers.Input(shape=input_shape, name="input1") + input2 = layers.Input(shape=input_shape, name="input2") + x1 = layers.Dense(10, activation="relu")(input1) + x2 = layers.Dense(10, activation="relu")(input2) + combined = layers.concatenate([x1, x2]) + output = layers.Dense(1, activation="sigmoid")(combined) + return models.Model(inputs=[input1, input2], outputs=output) + if type == "multi_output": + inputs = layers.Input(shape=input_shape) + shared = layers.Dense(20, activation="relu")(inputs) + output1 = layers.Dense(1, activation="sigmoid", name="output1")(shared) + output2 = layers.Dense(3, activation="softmax", name="output2")(shared) + return models.Model(inputs=inputs, outputs=[output1, output2]) + raise ValueError(f"Unknown model type: {type}") + + +def _convert_to_numpy(structure): + return tree.map_structure( + lambda x: x.numpy() if hasattr(x, "numpy") else np.array(x), structure + ) + + +def _normalize_name(name): + normalized = name.split(":")[0] + if normalized.startswith("serving_default_"): + normalized = normalized[len("serving_default_") :] + return normalized + + +def _set_interpreter_inputs(interpreter, inputs): + input_details = 
interpreter.get_input_details() + if isinstance(inputs, dict): + for detail in input_details: + key = _normalize_name(detail["name"]) + if key in inputs: + value = inputs[key] + else: + matched_key = None + for candidate in inputs: + if key.endswith(candidate) or candidate.endswith(key): + matched_key = candidate + break + if matched_key is None: + raise KeyError( + f"Unable to match input '{detail['name']}' in provided " + f"inputs" + ) + value = inputs[matched_key] + interpreter.set_tensor(detail["index"], value) + else: + values = inputs + if not isinstance(values, (list, tuple)): + values = [values] + if len(values) != len(input_details): + raise ValueError( + "Number of provided inputs does not match interpreter signature" + ) + for detail, value in zip(input_details, values): + interpreter.set_tensor(detail["index"], value) + + +def _get_interpreter_outputs(interpreter): + output_details = interpreter.get_output_details() + outputs = [ + interpreter.get_tensor(detail["index"]) for detail in output_details + ] + return outputs[0] if len(outputs) == 1 else outputs + + +@pytest.mark.skipif( + not tensorflow.available, + reason="TensorFlow is required for LiteRT export tests.", +) +@pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="`export_litert` currently supports the TensorFlow backend only.", +) +@pytest.mark.skipif( + testing.tensorflow_uses_gpu(), + reason="LiteRT export tests are only run on CPU to avoid CI issues.", +) +class ExportLitertTest(testing.TestCase): + """Test suite for LiteRT (TFLite) model export functionality. + + Tests use AI Edge LiteRT interpreter when available, otherwise fall back + to TensorFlow Lite interpreter for validation. + """ + + @parameterized.named_parameters(named_product(model_type=model_types)) + def test_standard_model_export(self, model_type): + """Test exporting standard model types to LiteRT format.""" + if model_type == "lstm" and not AI_EDGE_LITERT_AVAILABLE: + self.skipTest("LSTM models require AI Edge LiteRT interpreter.") + + temp_filepath = os.path.join( + self.get_temp_dir(), "exported_model.tflite" + ) + model = get_model(model_type) + batch_size = 1 # LiteRT expects batch_size=1 + if model_type == "lstm": + ref_input = np.random.normal(size=(batch_size, 4, 10)) + else: + ref_input = np.random.normal(size=(batch_size, 10)) + ref_input = ref_input.astype("float32") + ref_output = _convert_to_numpy(model(ref_input)) + + # Test with model.export() + model.export(temp_filepath, format="litert") + export_path = temp_filepath + self.assertTrue(os.path.exists(export_path)) + + interpreter = LiteRtInterpreter(model_path=export_path) + interpreter.allocate_tensors() + _set_interpreter_inputs(interpreter, ref_input) + interpreter.invoke() + litert_output = _get_interpreter_outputs(interpreter) + + self.assertAllClose(ref_output, litert_output, atol=1e-4, rtol=1e-4) + + @parameterized.named_parameters( + named_product(struct_type=["tuple", "array", "dict"]) + ) + def test_model_with_input_structure(self, struct_type): + """Test exporting models with structured inputs (tuple/array/dict).""" + batch_size = 1 # LiteRT expects batch_size=1 + base_input = np.random.normal(size=(batch_size, 10)).astype("float32") + + if struct_type == "tuple": + # Use Functional API for proper Input layer handling + input1 = layers.Input(shape=(10,), name="input_1") + input2 = layers.Input(shape=(10,), name="input_2") + output = layers.Add()([input1, input2]) + model = models.Model(inputs=[input1, input2], outputs=output) + ref_input = (base_input, 
base_input * 2) + elif struct_type == "array": + # Use Functional API for proper Input layer handling + input1 = layers.Input(shape=(10,), name="input_1") + input2 = layers.Input(shape=(10,), name="input_2") + output = layers.Add()([input1, input2]) + model = models.Model(inputs=[input1, input2], outputs=output) + ref_input = [base_input, base_input * 2] + elif struct_type == "dict": + # Use Functional API for proper Input layer handling + input1 = layers.Input(shape=(10,), name="x") + input2 = layers.Input(shape=(10,), name="y") + output = layers.Add()([input1, input2]) + model = models.Model( + inputs={"x": input1, "y": input2}, outputs=output + ) + ref_input = {"x": base_input, "y": base_input * 2} + else: + raise AssertionError("Unexpected structure type") + + temp_filepath = os.path.join( + self.get_temp_dir(), "exported_model.tflite" + ) + ref_output = _convert_to_numpy( + model(tree.map_structure(ops.convert_to_tensor, ref_input)) + ) + + # Test with model.export() + model.export(temp_filepath, format="litert") + export_path = temp_filepath + interpreter = LiteRtInterpreter(model_path=export_path) + interpreter.allocate_tensors() + + feed_inputs = ref_input + if isinstance(feed_inputs, tuple): + feed_inputs = list(feed_inputs) + _set_interpreter_inputs(interpreter, feed_inputs) + interpreter.invoke() + litert_output = _get_interpreter_outputs(interpreter) + + self.assertAllClose(ref_output, litert_output, atol=1e-4, rtol=1e-4) + + # Verify export still works after saving/loading via saving_lib. + archive_path = os.path.join(self.get_temp_dir(), "revived.keras") + saving_lib.save_model(model, archive_path) + revived_model = saving_lib.load_model(archive_path) + revived_output = _convert_to_numpy(revived_model(ref_input)) + self.assertAllClose(ref_output, revived_output) + + def test_model_with_multiple_inputs(self): + """Test exporting models with multiple inputs and batch resizing.""" + temp_filepath = os.path.join( + self.get_temp_dir(), "exported_model.tflite" + ) + + # Use Functional API for proper Input layer handling + input_x = layers.Input(shape=(10,), name="x") + input_y = layers.Input(shape=(10,), name="y") + output = layers.Add()([input_x, input_y]) + model = models.Model(inputs=[input_x, input_y], outputs=output) + + batch_size = 1 # LiteRT expects batch_size=1 + ref_input_x = np.random.normal(size=(batch_size, 10)).astype("float32") + ref_input_y = np.random.normal(size=(batch_size, 10)).astype("float32") + ref_output = _convert_to_numpy(model([ref_input_x, ref_input_y])) + + # Test with model.export() + model.export(temp_filepath, format="litert") + export_path = temp_filepath + interpreter = LiteRtInterpreter(model_path=export_path) + interpreter.allocate_tensors() + + _set_interpreter_inputs(interpreter, [ref_input_x, ref_input_y]) + interpreter.invoke() + litert_output = _get_interpreter_outputs(interpreter) + + self.assertAllClose(ref_output, litert_output, atol=1e-4, rtol=1e-4) + + # Test with a different batch size by resizing interpreter inputs. 
+ larger_x = np.concatenate([ref_input_x, ref_input_x], axis=0) + larger_y = np.concatenate([ref_input_y, ref_input_y], axis=0) + input_details = interpreter.get_input_details() + interpreter.resize_tensor_input( + input_details[0]["index"], larger_x.shape + ) + interpreter.resize_tensor_input( + input_details[1]["index"], larger_y.shape + ) + interpreter.allocate_tensors() + _set_interpreter_inputs(interpreter, [larger_x, larger_y]) + interpreter.invoke() + larger_output = _get_interpreter_outputs(interpreter) + larger_ref_output = _convert_to_numpy(model([larger_x, larger_y])) + self.assertAllClose( + larger_ref_output, larger_output, atol=1e-4, rtol=1e-4 + ) + + def test_export_with_custom_input_signature(self): + """Test exporting with custom input signature specification.""" + model = get_model("sequential") + temp_filepath = os.path.join( + self.get_temp_dir(), "exported_model.tflite" + ) + input_signature = [layers.InputSpec(shape=(None, 10), dtype="float32")] + + # Test with model.export() + model.export( + temp_filepath, + format="litert", + input_signature=input_signature, + ) + export_path = temp_filepath + self.assertTrue(os.path.exists(export_path)) + + interpreter = LiteRtInterpreter(model_path=export_path) + interpreter.allocate_tensors() + input_details = interpreter.get_input_details() + self.assertEqual(len(input_details), 1) + self.assertEqual(tuple(input_details[0]["shape"][1:]), (10,)) + + def test_multi_output_model_export(self): + """Test exporting multi-output models.""" + model = get_model("multi_output") + + # Build the model + ref_input = np.random.normal(size=(3, 10)).astype("float32") + model(ref_input) + + temp_filepath = os.path.join( + self.get_temp_dir(), "exported_model.tflite" + ) + model.export(temp_filepath, format="litert") + + tflite_path = temp_filepath + self.assertTrue(os.path.exists(tflite_path)) + + # Test inference + interpreter = LiteRtInterpreter(model_path=tflite_path) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + self.assertEqual(len(output_details), 2) + + test_input = np.random.random(input_details[0]["shape"]).astype( + np.float32 + ) + interpreter.set_tensor(input_details[0]["index"], test_input) + interpreter.invoke() + + for detail in output_details: + output = interpreter.get_tensor(detail["index"]) + self.assertIsInstance(output, np.ndarray) + + def test_export_with_verbose(self): + """Test export with verbose output.""" + model = get_model("sequential") + dummy_input = np.random.random((3, 10)).astype(np.float32) + model(dummy_input) + + temp_filepath = os.path.join( + self.get_temp_dir(), "exported_model.tflite" + ) + + # Export with verbose=True + model.export(temp_filepath, format="litert", verbose=True) + + tflite_path = temp_filepath + self.assertTrue(os.path.exists(tflite_path)) + + # Verify the exported model works + interpreter = LiteRtInterpreter(model_path=tflite_path) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(len(input_details), 1) + + def test_export_error_handling(self): + """Test error handling in export API.""" + model = get_model("sequential") + dummy_input = np.random.random((3, 10)).astype(np.float32) + model(dummy_input) + + temp_filepath = os.path.join( + self.get_temp_dir(), "exported_model.tflite" + ) + + # Test with invalid format + with self.assertRaises(ValueError): + model.export(temp_filepath, format="invalid_format") + + def 
test_export_invalid_filepath(self): + """Test that export fails with invalid file extension.""" + model = get_model("sequential") + dummy_input = np.random.random((3, 10)).astype(np.float32) + model(dummy_input) + + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model.txt") + + # Should raise AssertionError for wrong extension + with self.assertRaises(AssertionError): + model.export(temp_filepath, format="litert") + + def test_export_subclass_model(self): + """Test exporting subclass models (uses wrapper conversion path).""" + if LiteRtInterpreter is None: + self.skipTest("No LiteRT interpreter available") + + model = get_model("subclass") + temp_filepath = os.path.join( + self.get_temp_dir(), "exported_model.tflite" + ) + + batch_size = 1 + ref_input = np.random.normal(size=(batch_size, 10)).astype("float32") + ref_output = _convert_to_numpy(model(ref_input)) + + # Export subclass model - this tests wrapper-based conversion + model.export(temp_filepath, format="litert") + self.assertTrue(os.path.exists(temp_filepath)) + + # Verify inference + interpreter = LiteRtInterpreter(model_path=temp_filepath) + interpreter.allocate_tensors() + _set_interpreter_inputs(interpreter, ref_input) + interpreter.invoke() + litert_output = _get_interpreter_outputs(interpreter) + + self.assertAllClose(ref_output, litert_output, atol=1e-4, rtol=1e-4) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/keras/src/models/model.py b/keras/src/models/model.py index e8fa6415b10..7bfb4d6c95c 100644 --- a/keras/src/models/model.py +++ b/keras/src/models/model.py @@ -136,7 +136,6 @@ def call(self, inputs, training=False): keras.Input(shape=(None, None, 3)), keras.layers.Conv2D(filters=32, kernel_size=3), ]) - ``` """ def __new__(cls, *args, **kwargs): @@ -569,8 +568,8 @@ def export( filepath: `str` or `pathlib.Path` object. The path to save the artifact. format: `str`. The export format. Supported values: - `"tf_saved_model"` and `"onnx"`. Defaults to - `"tf_saved_model"`. + `"tf_saved_model"`, `"onnx"`, `"openvino"`, and `"litert"`. + Defaults to `"tf_saved_model"`. verbose: `bool`. Whether to print a message during export. Defaults to `None`, which uses the default value set by different backends and formats. @@ -593,6 +592,18 @@ def export( provided, they will be automatically computed. - `opset_version`: Optional `int`. Specific to `format="onnx"`. An integer value that specifies the ONNX opset version. + - `allow_custom_ops`: Optional `bool`. Specific to + `format="litert"`. + Whether to allow custom operations during conversion. + Defaults to `False`. + - `enable_select_tf_ops`: Optional `bool`. Specific to + `format="litert"`. + Whether to enable TensorFlow Select ops for unsupported + operations. Defaults to `False`. + - `optimizations`: Optional `list`. Specific to + `format="litert"`. + List of optimizations to apply (e.g., + `[tf.lite.Optimize.DEFAULT]`). **Note:** This feature is currently supported only with TensorFlow, JAX and Torch backends. @@ -627,18 +638,44 @@ def export( } predictions = ort_session.run(None, ort_inputs) ``` + + Here's how to export a LiteRT (TFLite) for inference. 
+ + ```python + # Export the model as a LiteRT artifact + model.export("path/to/location.tflite", format="litert") + + # Load the artifact in a different process/environment + interpreter = tf.lite.Interpreter(model_path="path/to/location.tflite") + interpreter.allocate_tensors() + interpreter.set_tensor( + interpreter.get_input_details()[0]['index'], input_data + ) + interpreter.invoke() + output_data = interpreter.get_tensor( + interpreter.get_output_details()[0]['index'] + ) + ``` """ + from keras.src.export import export_litert from keras.src.export import export_onnx from keras.src.export import export_openvino from keras.src.export import export_saved_model - available_formats = ("tf_saved_model", "onnx", "openvino") + available_formats = ("tf_saved_model", "onnx", "openvino", "litert") if format not in available_formats: raise ValueError( f"Unrecognized format={format}. Supported formats are: " f"{list(available_formats)}." ) + # Check if LiteRT export is available (requires TensorFlow) + if format == "litert" and export_litert is None: + raise ImportError( + "LiteRT export requires TensorFlow to be installed. " + "Please install TensorFlow: `pip install tensorflow`" + ) + if format == "tf_saved_model": export_saved_model( self, @@ -663,6 +700,14 @@ def export( input_signature=input_signature, **kwargs, ) + elif format == "litert": + export_litert( + self, + filepath, + verbose=verbose, + input_signature=input_signature, + **kwargs, + ) @classmethod def from_config(cls, config, custom_objects=None): diff --git a/keras/src/utils/module_utils.py b/keras/src/utils/module_utils.py index 286394a9935..577d08a7fd4 100644 --- a/keras/src/utils/module_utils.py +++ b/keras/src/utils/module_utils.py @@ -59,3 +59,4 @@ def __repr__(self): dmtree = LazyModule("tree") tf2onnx = LazyModule("tf2onnx") grain = LazyModule("grain") +litert = LazyModule("ai_edge_litert") diff --git a/requirements-tensorflow-cuda.txt b/requirements-tensorflow-cuda.txt index f895f022415..05125b381c9 100644 --- a/requirements-tensorflow-cuda.txt +++ b/requirements-tensorflow-cuda.txt @@ -1,6 +1,7 @@ # Tensorflow with cuda support. tensorflow[and-cuda]~=2.18.1 tf2onnx +ai-edge-litert # Torch cpu-only version (needed for testing). --extra-index-url https://download.pytorch.org/whl/cpu
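Below is a minimal end-to-end sketch of the export path added in this change, for reference. It assumes the TensorFlow backend and an installed `ai-edge-litert` package (as pinned in the requirements above); the model definition and the `model.tflite` file name are illustrative only and not part of the patch.

```python
import numpy as np
import keras
# The LiteRT interpreter comes from the ai-edge-litert package.
from ai_edge_litert.interpreter import Interpreter

# Any built Keras model works; this small model is just for illustration.
model = keras.Sequential(
    [
        keras.Input(shape=(10,)),
        keras.layers.Dense(8, activation="relu"),
        keras.layers.Dense(1, activation="sigmoid"),
    ]
)

# The LiteRT exporter requires the output path to end with ".tflite".
model.export("model.tflite", format="litert")

# Run inference on the exported artifact with the LiteRT interpreter.
interpreter = Interpreter(model_path="model.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

x = np.random.random(input_details[0]["shape"]).astype(np.float32)
interpreter.set_tensor(input_details[0]["index"], x)
interpreter.invoke()
y = interpreter.get_tensor(output_details[0]["index"])
print(y.shape)  # e.g. (1, 1) for a batch of one
```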