Skip to content

Commit f0d9a4d

Browse files
Fix freezing of ResourceGather op (#1672)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 2be4cf3 commit f0d9a4d

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

tf2onnx/tf_loader.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,23 @@ def fix_freezing_errors(graph_def):
193193
return graph_def
194194

195195

196+
def fix_freezing_errors_part2(graph_def):
197+
# Sometimes tf freezing fails to convert ResourceGather ops in subgraphs
198+
for f in graph_def.library.function:
199+
for n in f.node_def:
200+
if n.op == "ResourceGather":
201+
# Convert to standard Gather op. Freezing will have replaced resource with constant.
202+
# Needed because of: https://github.com/tensorflow/tensorflow/issues/51488
203+
n.op = "Gather"
204+
n.attr['Tparams'].type = n.attr['dtype'].type
205+
del n.attr['dtype']
206+
if 'batch_dims' in n.attr:
207+
v = n.attr['batch_dims'].i
208+
utils.make_sure(v == 0, "Unsupported batch_dims value of ResourceGather %d", v)
209+
del n.attr['batch_dims']
210+
return graph_def
211+
212+
196213
def from_trackable(trackable, concrete_func, inputs, outputs, large_model):
197214
err_large_model = "model exceeds maximum protobuf size of 2GB. Try setting large_model."
198215

@@ -263,6 +280,7 @@ def from_function(func, input_names, output_names, large_model=False):
263280
raise e
264281
else:
265282
graph_def = frozen_func.graph.as_graph_def(add_shapes=True)
283+
graph_def = fix_freezing_errors_part2(graph_def)
266284

267285
# output_names = [i.name for i in frozen_func.outputs]
268286
with tf.Graph().as_default() as tf_graph:

0 commit comments

Comments
 (0)