|
| 1 | +# Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | +# Licensed under the MIT license. |
| 3 | + |
| 4 | +""" |
| 5 | +tf2onnx.rewriter.eye_rewriter - supports tf.eye |
| 6 | +""" |
| 7 | + |
| 8 | +from onnx import onnx_pb |
| 9 | +from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher |
| 10 | + |
| 11 | +# pylint: disable=invalid-name,unused-argument,missing-docstring, unused-variable |
| 12 | + |
| 13 | + |
| 14 | +def rewrite_eye(g, ops): |
| 15 | + # tf.eye is implemented by a sub_graph which contains op "MatrixDiag" or "MatrixSetDiag" while |
| 16 | + # these two ops are un-supported directly in onnx |
| 17 | + # but onnx op EyeLike can be used to map the sub_graph |
| 18 | + # "rewrite_eye" supports tf.eye(tf.shape(x)[i]) and tf.eye(tf.shape(x)[i], tf.shape(x)[j]). |
| 19 | + |
| 20 | + # ConstantOfShape in opset 9 is used, so if opset less than 9 then do nothing |
| 21 | + if g.opset < 9: |
| 22 | + return g.get_nodes() |
| 23 | + |
| 24 | + pattern1 = \ |
| 25 | + OpTypePattern("MatrixDiag", name="output_eye_matrix", inputs=[ |
| 26 | + OpTypePattern("Fill", inputs=[ |
| 27 | + OpTypePattern("Const"), |
| 28 | + OpTypePattern("ConcatV2", inputs=[ |
| 29 | + "*", |
| 30 | + "*", |
| 31 | + OpTypePattern("Pack", inputs=[ |
| 32 | + OpTypePattern("Minimum", name="min_node") |
| 33 | + ]) |
| 34 | + ]) |
| 35 | + ]) |
| 36 | + ]) |
| 37 | + pattern2 = \ |
| 38 | + OpTypePattern("MatrixSetDiag", name="output_eye_matrix", inputs=[ |
| 39 | + OpTypePattern("Fill"), |
| 40 | + OpTypePattern("Fill", inputs=[ |
| 41 | + OpTypePattern("Const"), |
| 42 | + OpTypePattern("ConcatV2", inputs=[ |
| 43 | + "*", |
| 44 | + "*", |
| 45 | + OpTypePattern("Pack", inputs=[ |
| 46 | + OpTypePattern("Minimum", name="min_node") |
| 47 | + ]) |
| 48 | + ]) |
| 49 | + ]) |
| 50 | + ]) |
| 51 | + |
| 52 | + for pattern in [pattern1, pattern2]: |
| 53 | + matcher = GraphMatcher(pattern, allow_reorder=True) |
| 54 | + match_results = list(matcher.match_ops(ops)) |
| 55 | + for match_result in match_results: |
| 56 | + old_output = match_result.get_op("output_eye_matrix") |
| 57 | + output_dtypes = [g.get_dtype(old_output.output[0])] |
| 58 | + output_shapes = [g.get_shape(old_output.output[0])] |
| 59 | + g.remove_node(old_output.name) |
| 60 | + |
| 61 | + min_node = match_result.get_op("min_node") |
| 62 | + num_rows = min_node.inputs[0] |
| 63 | + num_columns = min_node.inputs[1] |
| 64 | + |
| 65 | + # onnx op "EyeLike" need a 2D tensor, so generate it |
| 66 | + num_rows = g.make_node("Unsqueeze", num_rows.output, attr={"axes": [0]}) |
| 67 | + num_columns = g.make_node("Unsqueeze", num_columns.output, attr={"axes": [0]}) |
| 68 | + matrix_shape = g.make_node("Concat", [num_rows.output[0], num_columns.output[0]], attr={"axis": 0}) |
| 69 | + # cast nodes added for "ConstantOfShape" in ONNX only accepts int64 data. |
| 70 | + matrix_shape_int64 = g.make_node("Cast", matrix_shape.output, attr={"to": onnx_pb.TensorProto.INT64}) |
| 71 | + zero_matrix = g.make_node("ConstantOfShape", matrix_shape_int64.output) |
| 72 | + |
| 73 | + new_output = g.make_node("EyeLike", zero_matrix.output, attr={"dtype": output_dtypes[0]}, |
| 74 | + name=old_output.name, shapes=output_shapes, dtypes=output_dtypes) |
| 75 | + |
| 76 | + return g.get_nodes() |
0 commit comments