Skip to content

Commit 1c9c02d

Browse files
Merge pull request #1244 from onnx/tom/SimplifyTAGetItem
Simplified TA get item to avoid shape bug
2 parents 200ceed + a074cfc commit 1c9c02d

File tree

1 file changed

+1
-7
lines changed

1 file changed

+1
-7
lines changed

tf2onnx/onnx_opset/controlflow.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from tf2onnx import utils
2020
from tf2onnx.handler import tf_op
2121
from tf2onnx.tf_loader import find_function
22-
from tf2onnx.graph_builder import GraphBuilder
2322

2423

2524
logger = logging.getLogger(__name__)
@@ -286,12 +285,7 @@ class TensorListGetItem:
286285
def version_7(cls, ctx, node, **kwargs):
287286
ctx.ta_reads.append(node.input[0])
288287
node.type = "Gather"
289-
g = GraphBuilder(ctx)
290-
291-
usq_node = g.make_unsqueeze({"axes": [0], 'data': node.input[1]}, name=node.child_name(), return_node=True)
292-
ctx.replace_inputs(node, [node.input[0], usq_node.output[0]])
293-
sq_node = g.make_squeeze({"axes": [0], 'data': node.output[0]}, name=node.child_name(), return_node=True)
294-
ctx.insert_node_on_output(sq_node)
288+
ctx.replace_inputs(node, [node.input[0], node.input[1]])
295289

296290
@classmethod
297291
def version_13(cls, ctx, node, **kwargs):

0 commit comments

Comments
 (0)