|
33 | 33 | from tf2onnx.rewriter.rnn import rewrite_single_direction_gru
|
34 | 34 | from tf2onnx.rewriter.rnn import rewrite_single_direction_grublock
|
35 | 35 | from tf2onnx.rewriter.rnn import rewrite_single_direction_lstm, rewrite_bi_direction_lstm
|
36 |
| -from tf2onnx.shape_inference import infer_shape_for_graph, set_shape_from_inputs_broadcast |
| 36 | +from tf2onnx.shape_inference import infer_shape_for_graph |
37 | 37 | from tf2onnx.utils import port_name
|
38 | 38 |
|
39 | 39 | logging.basicConfig(level=logging.INFO)
|
@@ -1633,6 +1633,13 @@ def logical_compare_op(ctx, node, name, args):
|
1633 | 1633 | ctx.copy_shape(inp, inp_cast.output[0])
|
1634 | 1634 | ctx.set_dtype(inp_cast.output[0], target_dtype)
|
1635 | 1635 |
|
def logical_compareeq_op(ctx, node, name, args):
    """Convert TF GreaterEqual/LessEqual by negating the strict inverse comparison.

    The opset table registers this handler with the inverse strict op as args
    (["Less"] for GreaterEqual, ["Greater"] for LessEqual — see the mapping
    entries), so emitting that strict comparison and appending a Not on its
    output yields the *-or-equal semantics.
    """
    # Emit the strict comparison first; it also takes care of any input casts.
    logical_compare_op(ctx, node, name, [])
    out = node.output[0]
    # Negate the strict result on its output edge.
    neg = ctx.insert_new_node_on_output("Not", out, name=utils.make_name(name))
    # Propagate dtype and shape from the comparison output to the Not output.
    ctx.set_dtype(neg.output[0], ctx.get_dtype(out))
    ctx.copy_shape(out, neg.output[0])
1636 | 1643 |
|
1637 | 1644 | def where_op(ctx, node, name, args):
|
1638 | 1645 | # T_y output = Where(T_x condition), return indices of elements whose value are True
|
@@ -1777,6 +1784,8 @@ def where_op(ctx, node, name, args):
|
1777 | 1784 | "FusedBatchNormV2": (fused_batchnorm_op7, []),
|
1778 | 1785 | "Greater": (logical_compare_op, []),
|
1779 | 1786 | "Less": (logical_compare_op, []),
|
| 1787 | + "GreaterEqual": (logical_compareeq_op, ["Less"]), |
| 1788 | + "LessEqual": (logical_compareeq_op, ["Greater"]), |
1780 | 1789 | "LogicalAnd": (broadcast_op7, ["And"]),
|
1781 | 1790 | "LogicalOr": (broadcast_op7, ["Or"]),
|
1782 | 1791 | "MatrixBandPart": (matrixbandpart_op, []),
|
@@ -1812,9 +1821,7 @@ def where_op(ctx, node, name, args):
|
1812 | 1821 | "Cosh": (direct_op, []),
|
1813 | 1822 | "Erf": (direct_op, []),
|
1814 | 1823 | "Fill": (fill_op, []),
|
1815 |
| - "Greater": (logical_compare_op, []), |
1816 | 1824 | "IsNan": (direct_op, ["IsNaN"]),
|
1817 |
| - "Less": (logical_compare_op, []), |
1818 | 1825 | "ResizeBilinear": (upsample_op9, ["Upsample", "linear"]),
|
1819 | 1826 | "ResizeNearestNeighbor": (upsample_op9, ["Upsample", "nearest"]),
|
1820 | 1827 | "Sign": (sign_op9, []),
|
@@ -2105,50 +2112,6 @@ def rewrite_constant_fold(g, ops):
|
2105 | 2112 | return ops
|
2106 | 2113 |
|
2107 | 2114 |
|
2108 |
| -def rewrite_logical_compare_with_equal(g, ops): |
2109 |
| - patterns = {"GreaterEqual": "Greater", |
2110 |
| - "LessEqual": "Less"} |
2111 |
| - for p in patterns: |
2112 |
| - pattern = OpTypePattern(p, name='compare_with_equal') |
2113 |
| - compare_name = patterns[p] |
2114 |
| - matcher = GraphMatcher(pattern) |
2115 |
| - match_results = list(matcher.match_ops(ops)) |
2116 |
| - for match in match_results: |
2117 |
| - nodes_to_append = [] |
2118 |
| - compare_e_op = match.get_op('compare_with_equal') |
2119 |
| - data_type = g.get_dtype(compare_e_op.input[0]) |
2120 |
| - compare_input_ids = compare_e_op.input |
2121 |
| - need_cast = g.opset < 9 and data_type not in (onnx_pb.TensorProto.FLOAT16, |
2122 |
| - onnx_pb.TensorProto.FLOAT, |
2123 |
| - onnx_pb.TensorProto.DOUBLE) |
2124 |
| - if need_cast: |
2125 |
| - compare_input_ids = [] |
2126 |
| - for input_id in compare_e_op.input: |
2127 |
| - cast_node = g.make_node("Cast", [input_id], op_name_scope=compare_e_op.name, |
2128 |
| - attr={"to": onnx_pb.TensorProto.FLOAT}, shapes=[g.get_shape(input_id)], |
2129 |
| - dtypes=[onnx_pb.TensorProto.FLOAT]) |
2130 |
| - compare_input_ids.append(cast_node.output[0]) |
2131 |
| - nodes_to_append.append(cast_node) |
2132 |
| - |
2133 |
| - g_node = g.make_node(compare_name, compare_input_ids, op_name_scope=compare_e_op.name, |
2134 |
| - dtypes=[onnx_pb.TensorProto.BOOL]) |
2135 |
| - set_shape_from_inputs_broadcast(g, compare_input_ids, g_node.output[0]) |
2136 |
| - new_shape = g.get_shape(g_node.output[0]) |
2137 |
| - nodes_to_append.append(g_node) |
2138 |
| - |
2139 |
| - e_node = g.make_node("Equal", compare_e_op.input, op_name_scope=compare_e_op.name, |
2140 |
| - shapes=[new_shape], dtypes=[onnx_pb.TensorProto.BOOL]) |
2141 |
| - nodes_to_append.append(e_node) |
2142 |
| - |
2143 |
| - compare_e_op.type = "LogicalOr" |
2144 |
| - compare_e_op.input[0] = g_node.output[0] |
2145 |
| - compare_e_op.input[1] = e_node.output[0] |
2146 |
| - g.set_dtype(compare_e_op.output[0], onnx_pb.TensorProto.BOOL) |
2147 |
| - g.set_shape(compare_e_op.output[0], new_shape) |
2148 |
| - ops.extend(nodes_to_append) |
2149 |
| - return ops |
2150 |
| - |
2151 |
| - |
2152 | 2115 | def rewrite_incomplete_type_support(g, ops, impacted_ops):
|
2153 | 2116 | """
|
2154 | 2117 | for ops that have inclomplete type support, insert casts.
|
@@ -2459,7 +2422,7 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
|
2459 | 2422 | rewrite_leakyrelu, rewrite_conv2d_with_pad,
|
2460 | 2423 | rewrite_single_direction_lstm, rewrite_bi_direction_lstm,
|
2461 | 2424 | rewrite_single_direction_gru, rewrite_single_direction_grublock,
|
2462 |
| - rewrite_bi_direction_gru, rewrite_logical_compare_with_equal, |
| 2425 | + rewrite_bi_direction_gru, |
2463 | 2426 | rewrite_custom_rnn_cell, rewrite_generic_loop, rewrite_cond
|
2464 | 2427 | ]
|
2465 | 2428 |
|
|
0 commit comments