Skip to content

Commit 89c4c5c

Browse files
authored
The from_tflite() function should accept None as default value of input_names and output_names. (#1967)
* The from_tflite() function should not change the value of None to an empty list for input_names and output_names. * Change the way to validate a list is None or Emtpy. Signed-off-by: Jay Zhang <[email protected]>
1 parent 9cea907 commit 89c4c5c

File tree

3 files changed

+22
-10
lines changed

3 files changed

+22
-10
lines changed

tests/test_api.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -237,11 +237,27 @@ def test_tflite(self):
237237

238238
x_val = np.array([1.0, 2.0, -3.0, -4.0], dtype=np.float32).reshape((2, 2))
239239
model_proto, _ = tf2onnx.convert.from_tflite("tests/models/regression/tflite/test_api_model.tflite",
240-
input_names=['input'], output_names=['output'],
240+
input_names=["input"], output_names=["output"],
241241
output_path=output_path)
242-
output_names = [n.name for n in model_proto.graph.output]
243-
oy = self.run_onnxruntime(output_path, {"input": x_val}, output_names)
244-
self.assertTrue(output_names[0] == "output")
242+
actual_output_names = [n.name for n in model_proto.graph.output]
243+
oy = self.run_onnxruntime(output_path, {"input": x_val}, actual_output_names)
244+
245+
self.assertTrue(actual_output_names[0] == "output")
246+
exp_result = tf.add(x_val, x_val)
247+
self.assertAllClose(exp_result, oy[0], rtol=0.1, atol=0.1)
248+
249+
@check_tf_min_version("2.0")
250+
def test_tflite_without_input_output_names(self):
251+
output_path = os.path.join(self.test_data_directory, "model.onnx")
252+
253+
x_val = np.array([1.0, 2.0, -3.0, -4.0], dtype=np.float32).reshape((2, 2))
254+
model_proto, _ = tf2onnx.convert.from_tflite("tests/models/regression/tflite/test_api_model.tflite",
255+
output_path=output_path)
256+
actual_input_names = [n.name for n in model_proto.graph.input]
257+
actual_output_names = [n.name for n in model_proto.graph.output]
258+
oy = self.run_onnxruntime(output_path, {actual_input_names[0]: x_val}, output_names=None)
259+
260+
self.assertTrue(actual_output_names[0] == "output")
245261
exp_result = tf.add(x_val, x_val)
246262
self.assertAllClose(exp_result, oy[0], rtol=0.1, atol=0.1)
247263

tf2onnx/convert.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -663,10 +663,6 @@ def from_tflite(tflite_path, input_names=None, output_names=None, opset=None, cu
663663
"""
664664
if not tflite_path:
665665
raise ValueError("tflite_path needs to be provided")
666-
if not input_names:
667-
input_names = []
668-
if not output_names:
669-
output_names = []
670666

671667
with tf.device("/cpu:0"):
672668
model_proto, external_tensor_storage = _convert_common(

tf2onnx/tflite_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,9 +156,9 @@ def graphs_from_tflite(tflite_path, input_names=None, output_names=None):
156156
if is_main_g:
157157
# Override IO in main graph
158158
utils.check_io(input_names, output_names, output_shapes.keys())
159-
if input_names is not None:
159+
if input_names:
160160
g_inputs = input_names
161-
if output_names is not None:
161+
if output_names:
162162
g_outputs = output_names
163163
g = Graph(onnx_nodes, output_shapes, dtypes, input_names=g_inputs, output_names=g_outputs,
164164
is_subgraph=not is_main_g, graph_name=graph_name)

0 commit comments

Comments
 (0)