Skip to content

Commit 81bf626

Browse files
ecalubaquibGoogle-ML-Automation
authored andcommitted
Move converter related tflite functions to tensorflow/lite repo
PiperOrigin-RevId: 688270228
1 parent 11eeff0 commit 81bf626

File tree

5 files changed

+14
-53
lines changed

5 files changed

+14
-53
lines changed

jax/experimental/jax2tf/README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,6 @@ For more involved examples, please see examples involving:
103103

104104
* SavedModel for archival ([examples below](#usage-saved-model)), including
105105
saving [batch-polymorphic functions](#shape-polymorphic-conversion),
106-
* TensorFlow Lite ([examples](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/tflite/mnist/README.md)),
107106
* TensorFlow.js ([examples](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/tf_js/quickdraw/README.md)),
108107
* TFX ([examples](https://github.com/tensorflow/tfx/blob/master/tfx/examples/penguin/README.md#instructions-for-using-flax)),
109108
* TensorFlow Hub and Keras ([examples](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/README.md)).

jax/experimental/jax2tf/examples/README.md

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ This directory contains a number of examples of using the
99
* save SavedModel from trained MNIST models, using both Flax and pure JAX.
1010
* reuse the feature-extractor part of the trained MNIST model
1111
in a larger TensorFlow Keras model.
12-
* use Flax models with TensorFlow Serving, TensorFlow JavaScript, and TensorFlow Lite.
12+
* use Flax models with TensorFlow Serving, TensorFlow JavaScript.
1313

1414
You can also find usage examples in other projects:
1515

@@ -176,16 +176,15 @@ At the moment, the open-source TensorFlow model server is missing XLA support,
176176
but the Google version can be used, as shown in the
177177
[serving examples](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/serving/README.md).
178178

179-
# Using jax2tf with TensorFlow Lite and TensorFlow JavaScript
179+
# Using jax2tf with TensorFlow JavaScript
180180

181181
A jax2tf-generated SavedModel can also be converted to a format usable with
182-
TensorFlow Lite or TensorFlow.js, by using the appropriate converters from SavedModel.
182+
TensorFlow.js, by using the appropriate converters from SavedModel.
183183
At the moment, these converters may reject some jax2tf-generated SavedModels due to
184184
some ops not yet implemented in the converters. As a partial workaround, one
185185
can pass the `enable_xla=False` parameter to `jax2tf.convert` to direct
186186
`jax2tf` to avoid problematic ops. This will increase the coverage, and in fact
187187
most, but not all, Flax examples can be converted this way.
188188

189-
Check out the [MNIST TensorFlow Lite](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/tflite/mnist/README.md)
190-
and the
189+
Check out
191190
[Quickdraw TensorFlow.js example](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/tf_js/quickdraw/README.md).

jax/experimental/jax2tf/g3doc/BUILD

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@ licenses(["notice"])
1515

1616
package(
1717
default_applicable_licenses = [],
18-
default_visibility = ["//jax/experimental/jax2tf:__subpackages__"],
18+
default_visibility = [
19+
"//jax/experimental/jax2tf:__subpackages__",
20+
"//third_party/tensorflow/lite/experimental/mlir/testing/jax:__subpackages__",
21+
],
1922
)
2023

2124
filegroup(

jax/experimental/jax2tf/tests/converters.py

Lines changed: 2 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from typing import Any
2121

2222
from jax.experimental import jax2tf
23-
import tensorflow as tf
2423
import tensorflowjs as tfjs
2524

2625
from jax.experimental.jax2tf.tests.model_harness import ModelHarness
@@ -51,56 +50,14 @@ def jax2tfjs(harness: ModelHarness):
5150
model_dir=model_dir)
5251

5352

54-
def jax2tflite(harness: ModelHarness, use_flex_ops: bool = False):
55-
"""Returns a converter with Flex ops linked in iff `use_flex_ops==True`."""
56-
tf_fn = tf.function(
57-
jax2tf_convert(harness, enable_xla=False),
58-
input_signature=harness.tf_input_signature,
59-
autograph=False)
60-
apply_tf = tf_fn.get_concrete_function()
61-
converter = tf.lite.TFLiteConverter.from_concrete_functions([apply_tf], tf_fn)
62-
supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
63-
if use_flex_ops:
64-
supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS)
65-
converter.target_spec.supported_ops = supported_ops
66-
67-
# Convert the model.
68-
tflite_model = converter.convert()
69-
70-
# Construct an interpreter for doing a numerical comparison.
71-
interpreter = tf.lite.Interpreter(model_content=tflite_model)
72-
interpreter.allocate_tensors()
73-
74-
inputs = interpreter.get_input_details()
75-
output_details = interpreter.get_output_details()
76-
outputs = tuple(interpreter.tensor(out["index"]) for out in output_details)
77-
78-
def apply_tflite(*xs):
79-
assert len(xs) == len(inputs)
80-
for i, x in enumerate(xs):
81-
interpreter.set_tensor(inputs[i]['index'], x)
82-
interpreter.invoke()
83-
if len(outputs) > 1:
84-
return tuple(o() for o in outputs)
85-
else:
86-
return outputs[0]()
87-
88-
return apply_tflite
89-
90-
9153
ALL_CONVERTERS = [
9254
# jax2tf with XLA support (enable_xla=True).
9355
Converter(name='jax2tf_xla', convert_fn=jax2tf_convert),
9456
# jax2tf without XLA support (enable_xla=False).
9557
Converter(
9658
name='jax2tf_noxla',
97-
convert_fn=functools.partial(jax2tf_convert, enable_xla=False)),
59+
convert_fn=functools.partial(jax2tf_convert, enable_xla=False),
60+
),
9861
# Convert JAX to Tensorflow.JS.
9962
Converter(name='jax2tfjs', convert_fn=jax2tfjs, compare_numerics=False),
100-
# Convert JAX to TFLIte.
101-
Converter(name='jax2tflite', convert_fn=jax2tflite),
102-
# Convert JAX to TFLIte with support for Flex ops.
103-
Converter(
104-
name='jax2tflite+flex',
105-
convert_fn=functools.partial(jax2tflite, use_flex_ops=True))
10663
]

jax/experimental/jax2tf/tests/flax_models/BUILD

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@ licenses(["notice"])
1919

2020
package(
2121
default_applicable_licenses = [],
22-
default_visibility = ["//jax/experimental/jax2tf:__subpackages__"],
22+
default_visibility = [
23+
"//jax/experimental/jax2tf:__subpackages__",
24+
"//third_party/tensorflow/lite/experimental/mlir/testing/jax/test_models:__subpackages__",
25+
],
2326
)
2427

2528
py_library(

0 commit comments

Comments
 (0)