|
33 | 33 | from tf2onnx.rewriter.rnn import rewrite_single_direction_gru
|
34 | 34 | from tf2onnx.rewriter.rnn import rewrite_single_direction_grublock
|
35 | 35 | from tf2onnx.rewriter.rnn import rewrite_single_direction_lstm, rewrite_bi_direction_lstm
|
36 |
| -from tf2onnx.shape_inference import infer_shape_for_graph, set_shape_from_inputs_broadcast |
| 36 | +from tf2onnx.shape_inference import infer_shape_for_graph |
37 | 37 | from tf2onnx.utils import port_name
|
38 | 38 |
|
39 | 39 | logging.basicConfig(level=logging.INFO)
|
@@ -1635,8 +1635,10 @@ def logical_compare_op(ctx, node, name, args):
|
1635 | 1635 |
|
1636 | 1636 | def logical_compareeq_op(ctx, node, name, args):
|
1637 | 1637 | logical_compare_op(ctx, node, name, [])
|
1638 |
| - ctx.insert_new_node_on_output("Not", node.output[0], name=utils.make_name(name), |
1639 |
| - shapes=ctx.get_shape(node.output[0]), dtypes=ctx.get_dtype(node.output[0])) |
| 1638 | + output_name = node.output[0] |
| 1639 | + new_node = ctx.insert_new_node_on_output("Not", output_name, name=utils.make_name(name)) |
| 1640 | + ctx.copy_shape(output_name, new_node.output[0]) |
| 1641 | + ctx.set_dtype(new_node.output[0], ctx.get_dtype(output_name)) |
1640 | 1642 |
|
1641 | 1643 |
|
1642 | 1644 | def where_op(ctx, node, name, args):
|
|
0 commit comments