Skip to content

Commit 9ce72be

Browse files
authored
Add --outputs_as_nchw option to transpose output to from nhwc to nchw (#1979)
* add output_as_nchw Signed-off-by: Deyu Huang <[email protected]> * fix node replace logic Signed-off-by: Deyu Huang <[email protected]> * add tests for outputs as nchw Signed-off-by: Deyu Huang <[email protected]> * add it into function and doc Signed-off-by: Deyu Huang <[email protected]> * fix output_names_with_port range Signed-off-by: Deyu Huang <[email protected]> * fix the input_as_nchw description Signed-off-by: Deyu Huang <[email protected]> * change tests name Signed-off-by: Deyu Huang <[email protected]>
1 parent fa0b6cf commit 9ce72be

File tree

5 files changed

+114
-42
lines changed

5 files changed

+114
-42
lines changed

README.md

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -292,8 +292,8 @@ import tf2onnx
292292
model_proto, external_tensor_storage = tf2onnx.convert.from_keras(model,
293293
input_signature=None, opset=None, custom_ops=None,
294294
custom_op_handlers=None, custom_rewriter=None,
295-
inputs_as_nchw=None, extra_opset=None shape_override=None,
296-
target=None, large_model=False, output_path=None)
295+
inputs_as_nchw=None, outputs_as_nchw=None, extra_opset=None,
296+
shape_override=None, target=None, large_model=False, output_path=None)
297297
298298
Args:
299299
model: the tf.keras model we want to convert
@@ -307,7 +307,8 @@ model_proto, external_tensor_storage = tf2onnx.convert.from_keras(model,
307307
custom_rewriter: list of custom graph rewriters
308308
extra_opset: list of extra opset's, for example the opset's used by custom ops
309309
shape_override: dict with inputs that override the shapes given by tensorflow
310-
inputs_as_nchw: transpose inputs in list from nchw to nhwc
310+
inputs_as_nchw: transpose inputs in list from nhwc to nchw
311+
outputs_as_nchw: transpose outputs in list from nhwc to nchw
311312
large_model: use the ONNX external tensor storage format
312313
output_path: save model to output_path
313314
@@ -323,8 +324,8 @@ import tf2onnx
323324
324325
model_proto, external_tensor_storage = tf2onnx.convert.from_function(function,
325326
input_signature=None, opset=None, custom_ops=None,
326-
custom_op_handlers=None, custom_rewriter=None,
327-
inputs_as_nchw=None, extra_opset=None, shape_override=None,
327+
custom_op_handlers=None, custom_rewriter=None, inputs_as_nchw=None,
328+
outputs_as_nchw=None, extra_opset=None, shape_override=None,
328329
target=None, large_model=False, output_path=None)
329330
330331
Args:
@@ -339,7 +340,8 @@ model_proto, external_tensor_storage = tf2onnx.convert.from_function(function,
339340
custom_rewriter: list of custom graph rewriters
340341
extra_opset: list of extra opset's, for example the opset's used by custom ops
341342
shape_override: dict with inputs that override the shapes given by tensorflow
342-
inputs_as_nchw: transpose inputs in list from nchw to nhwc
343+
inputs_as_nchw: transpose inputs in list from nhwc to nchw
344+
outputs_as_nchw: transpose outputs in list from nhwc to nchw
343345
large_model: use the ONNX external tensor storage format
344346
output_path: save model to output_path
345347
@@ -354,7 +356,7 @@ import tf2onnx
354356
model_proto, external_tensor_storage = tf2onnx.convert.from_graph_def(graph_def,
355357
name=None, input_names=None, output_names=None, opset=None,
356358
custom_ops=None, custom_op_handlers=None, custom_rewriter=None,
357-
inputs_as_nchw=None, extra_opset=None,
359+
inputs_as_nchw=None, outputs_as_nchw=None, extra_opset=None,
358360
shape_override=None, target=None, large_model=False,
359361
output_path=None)
360362
@@ -369,7 +371,8 @@ model_proto, external_tensor_storage = tf2onnx.convert.from_graph_def(graph_def,
369371
custom_rewriter: list of custom graph rewriters
370372
extra_opset: list of extra opset's, for example the opset's used by custom ops
371373
shape_override: dict with inputs that override the shapes given by tensorflow
372-
inputs_as_nchw: transpose inputs in list from nchw to nhwc
374+
inputs_as_nchw: transpose inputs in list from nhwc to nchw
375+
outputs_as_nchw: transpose outputs in list from nhwc to nchw
373376
large_model: use the ONNX external tensor storage format
374377
output_path: save model to output_path
375378
@@ -383,8 +386,8 @@ import tf2onnx
383386
384387
model_proto, external_tensor_storage = tf2onnx.convert.from_tflite(tflite_path,
385388
input_names=None, output_names=None, opset=None, custom_ops=None, custom_op_handlers=None,
386-
custom_rewriter=None, inputs_as_nchw=None, extra_opset=None, shape_override=None, target=None,
387-
large_model=False, output_path=None):
389+
custom_rewriter=None, inputs_as_nchw=None, outputs_as_nchw=None, extra_opset=None,
390+
shape_override=None, target=None, large_model=False, output_path=None):
388391
389392
Args:
390393
tflite_path: the tflite model file full path
@@ -396,7 +399,8 @@ model_proto, external_tensor_storage = tf2onnx.convert.from_tflite(tflite_path,
396399
runtime can still open the model. Type is a dictionary `{op name: domain}`.
397400
custom_op_handlers: dictionary of custom ops handlers
398401
custom_rewriter: list of custom graph rewriters
399-
inputs_as_nchw: transpose inputs in list from nchw to nhwc
402+
inputs_as_nchw: transpose inputs in list from nhwc to nchw
403+
outputs_as_nchw: transpose outputs in list from nhwc to nchw
400404
extra_opset: list of extra opset's, for example the opset's used by custom ops
401405
shape_override: dict with inputs that override the shapes given by tensorflow
402406
target: list of workarounds applied to help certain platforms

tests/backend_test_base.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import onnx
2121
from common import get_test_config
2222
from tfjs_runner import run_tfjs
23+
from tf2onnx import constants
2324
from tf2onnx import utils
2425
from tf2onnx.tfonnx import process_tf_graph
2526
from tf2onnx import optimizer
@@ -366,6 +367,7 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
366367
graph_def_path = os.path.join(self.test_data_directory, self._testMethodName + "_after_tf_optimize.pb")
367368
utils.save_protobuf(graph_def_path, graph_def)
368369
self.logger.debug("created file %s", graph_def_path)
370+
tfl_process_args = process_args.copy()
369371

370372
if test_tfjs:
371373
tfjs_path = self.convert_to_tfjs(graph_def_path, output_names_with_port)
@@ -395,6 +397,10 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
395397
g = optimizer.optimize_graph(g, catch_errors=False)
396398
actual = self.run_backend(g, output_names_with_port, onnx_feed_dict, large_model,
397399
use_custom_ops=use_custom_ops)
400+
if 'outputs_as_nchw' in tfl_process_args:
401+
for output_name in tfl_process_args['outputs_as_nchw']:
402+
i = output_names_with_port.index(output_name)
403+
actual[i] = np.transpose(actual[i], constants.NCHW_TO_NHWC)
398404

399405
self.assert_results_equal(expected, actual, rtol, atol, mtol, check_value, check_shape, check_dtype)
400406
self.assert_shapes_correct(g, self.config.allow_missing_shapes, not self.config.skip_onnx_checker)
@@ -410,12 +416,14 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
410416
if run_tfl_consistency_test:
411417
self.assert_results_equal(expected, tfl_res, rtol, atol, mtol, check_value, check_shape, check_dtype)
412418

413-
tfl_process_args = process_args.copy()
414419
if 'inputs_as_nchw' in tfl_process_args:
415420
nchw_inps_with_port = tfl_process_args['inputs_as_nchw']
416421
tfl_process_args['inputs_as_nchw'] = [i.split(':')[0] for i in nchw_inps_with_port]
417422
input_names_without_port = [inp.split(':')[0] for inp in feed_dict.keys()]
418-
423+
if 'outputs_as_nchw' in tfl_process_args:
424+
nchw_outps_with_port = tfl_process_args['outputs_as_nchw']
425+
tfl_process_args['outputs_as_nchw'] = [i.split(':')[0] for i in nchw_outps_with_port]
426+
output_names_with_port = [i.split(':')[0] for i in nchw_outps_with_port]
419427
g = process_tf_graph(None, opset=self.config.opset,
420428
input_names=input_names_without_port,
421429
output_names=tfl_outputs,
@@ -427,6 +435,10 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
427435
onnx_feed_dict_without_port = {k.split(':')[0]: v for k, v in onnx_feed_dict.items()}
428436
onnx_tfl_res = self.run_backend(g, tfl_outputs, onnx_feed_dict_without_port,
429437
postfix="_from_tflite", use_custom_ops=use_custom_ops)
438+
if 'outputs_as_nchw' in tfl_process_args:
439+
for output_name in tfl_process_args['outputs_as_nchw']:
440+
i = output_names_with_port.index(output_name)
441+
onnx_tfl_res[i] = np.transpose(onnx_tfl_res[i], constants.NCHW_TO_NHWC)
430442

431443
self.assert_results_equal(tfl_res, onnx_tfl_res, rtol, atol, mtol, check_value, check_shape, check_dtype)
432444
self.assert_shapes_correct(g, self.config.allow_missing_shapes, not self.config.skip_onnx_checker)
@@ -456,6 +468,10 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
456468
g = optimizer.optimize_graph(g)
457469
onnx_tfjs_res = self.run_backend(g, None, onnx_feed_dict, large_model,
458470
postfix="_from_tfjs", use_custom_ops=use_custom_ops)
471+
if 'outputs_as_nchw' in tfl_process_args:
472+
for output_name in tfl_process_args['outputs_as_nchw']:
473+
i = output_names_with_port.index(output_name)
474+
onnx_tfjs_res[i] = np.transpose(onnx_tfjs_res[i], constants.NCHW_TO_NHWC)
459475

460476
self.assert_results_equal(tfjs_res, onnx_tfjs_res, rtol, atol, mtol, check_value, check_shape,
461477
check_dtype=False)

tests/test_backend.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -712,7 +712,7 @@ def func(x):
712712
graph_validator=lambda g: (check_op_count(g, "RandomUniform", 0) and
713713
check_op_count(g, "RandomUniformLike", 0)))
714714

715-
def test_conv2d_with_input_transpose(self):
715+
def test_inputs_as_nchw_arg(self):
716716
x_shape = [2, 32, 32, 3]
717717
kernel_shape = [3, 3, 3, 3]
718718
x_val = make_xval(x_shape)
@@ -725,6 +725,17 @@ def func(x):
725725
process_args={"inputs_as_nchw": [_INPUT]},
726726
onnx_feed_dict={_INPUT: x_val_for_onnx})
727727

728+
def test_outputs_as_nchw_arg(self):
729+
x_shape = [2, 32, 32, 3]
730+
kernel_shape = [3, 3, 3, 3]
731+
x_val = make_xval(x_shape)
732+
def func(x):
733+
kernel = tf.constant(make_xval(kernel_shape), dtype=tf.float32, name='kernel')
734+
conv = tf.nn.conv2d(x, kernel, strides=[1, 1, 1, 1], padding="SAME")
735+
return tf.identity(conv, name=_TFOUTPUT)
736+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05,
737+
process_args={"outputs_as_nchw": [_OUTPUT]})
738+
728739
@skip_tflite("TFlite adds ops that obscure pattern")
729740
@check_tf_min_version("1.15")
730741
def test_conv1d_dilations_rewriter(self):

0 commit comments

Comments
 (0)