Skip to content

Commit 95a98fc

Browse files
committed
set shape to fix some unit test
1 parent fd569ed commit 95a98fc

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

tf2onnx/tfonnx.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from tf2onnx.rewriter.rnn import rewrite_single_direction_gru
3434
from tf2onnx.rewriter.rnn import rewrite_single_direction_grublock
3535
from tf2onnx.rewriter.rnn import rewrite_single_direction_lstm, rewrite_bi_direction_lstm
36-
from tf2onnx.shape_inference import infer_shape_for_graph, set_shape_from_inputs_broadcast
36+
from tf2onnx.shape_inference import infer_shape_for_graph
3737
from tf2onnx.utils import port_name
3838

3939
logging.basicConfig(level=logging.INFO)
@@ -1635,8 +1635,10 @@ def logical_compare_op(ctx, node, name, args):
16351635

16361636
def logical_compareeq_op(ctx, node, name, args):
16371637
logical_compare_op(ctx, node, name, [])
1638-
ctx.insert_new_node_on_output("Not", node.output[0], name=utils.make_name(name),
1639-
shapes=ctx.get_shape(node.output[0]), dtypes=ctx.get_dtype(node.output[0]))
1638+
output_name = node.output[0]
1639+
new_node = ctx.insert_new_node_on_output("Not", output_name, name=utils.make_name(name))
1640+
ctx.copy_shape(output_name, new_node.output[0])
1641+
ctx.set_dtype(new_node.output[0], ctx.get_dtype(output_name))
16401642

16411643

16421644
def where_op(ctx, node, name, args):

0 commit comments

Comments
 (0)