Skip to content

Commit 3c69da2

Browse files
committed
ignore attr
1 parent da45397 commit 3c69da2

File tree

3 files changed

+10
-22
lines changed

3 files changed

+10
-22
lines changed

tf2onnx/onnx_opset/controlflow.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -377,8 +377,10 @@ class IteratorGetNext:
377377
@classmethod
378378
def version_8(cls, ctx, node, **kwargs):
379379
output_names = node.output
380+
type_0 = ctx.get_dtype(output_names[0])
381+
type_1 = ctx.get_dtype(output_names[1])
382+
shape_0 = ctx.get_shape(output_names[0])
383+
shape_1 = ctx.get_shape(output_names[1])
380384
ctx.remove_node(node.name)
381-
output_types = list(node.get_attr('output_types').ints)
382-
output_shapes = list(node.get_attr('output_shapes').ints)
383-
ctx.add_graph_input(output_names[0], output_types[0], output_shapes[:output_shapes.index(0)])
384-
ctx.add_graph_input(output_names[1], output_types[1], output_shapes[output_shapes.index(0)+1:-1])
385+
ctx.add_graph_input(output_names[0], type_0, shape_0)
386+
ctx.add_graph_input(output_names[1], type_1, shape_1)

tf2onnx/tfonnx.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ def tflist_to_onnx(node_list, shape_override):
4646
ignored_attr = ["unknown_rank", "_class", "Tshape", "use_cudnn_on_gpu", "Index", "Tpaddings",
4747
"TI", "Tparams", "Tindices", "Tlen", "Tdim", "dynamic_size", "Tmultiples",
4848
"Tblock_shape", "Tcrops", "index_type", "Taxis", "U", "maxval",
49-
"Tout", "Tlabels", "Tindex", "element_shape", "Targmax", "T_threshold"]
49+
"Tout", "Tlabels", "Tindex", "element_shape", "Targmax", "T_threshold",
50+
"output_types", "output_shapes", "key_dtype", "value_dtype", "Tin", "Tout"]
51+
5052
# some stats
5153
op_cnt = collections.Counter()
5254
attr_cnt = collections.Counter()
@@ -80,8 +82,7 @@ def tflist_to_onnx(node_list, shape_override):
8082
if dtype:
8183
if not isinstance(dtype, list):
8284
dtypes[node.name] = utils.map_tf_dtype(dtype)
83-
elif a in ["output_type", "output_dtype", "out_type", "Tidx",
84-
"out_idx", "key_dtype", "value_dtype", "Tin", "Tout"]:
85+
elif a in ["output_type", "output_dtype", "out_type", "Tidx", "out_idx"]:
8586
# Tidx is used by Range
8687
# out_idx is used by ListDiff
8788
attr[a] = utils.map_tf_dtype(utils.get_tf_node_attr(node, a))
@@ -100,10 +101,6 @@ def tflist_to_onnx(node_list, shape_override):
100101
continue
101102
elif a in ignored_attr:
102103
continue
103-
elif a == "output_types":
104-
attr[a] = [utils.map_tf_dtype(v) for v in utils.get_tf_node_attr(node, a)]
105-
elif a == "output_shapes":
106-
attr[a] = utils.get_tf_output_shapes_attr(node)
107104
else:
108105
attr[a] = utils.get_tf_node_attr(node, a)
109106

tf2onnx/utils.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -177,17 +177,6 @@ def get_tf_shape_attr(node):
177177
pass
178178
return dims
179179

180-
def get_tf_output_shapes_attr(node):
181-
"""Get output shapes from tensorflow attr "output_shapes"."""
182-
dims = []
183-
try:
184-
shapes = get_tf_node_attr(node, "output_shapes")
185-
for shape in shapes:
186-
dims.extend([d.size for d in shape.dim])
187-
dims.append(0)
188-
except: # pylint: disable=bare-except
189-
pass
190-
return dims
191180

192181
def get_tf_tensor_shape(tensor):
193182
shape = []

0 commit comments

Comments
 (0)