Skip to content

Commit e795995

Browse files
committed
workaround for the bug of tf fold_const
will add a fold_const feature later.
1 parent b5ee2f1 commit e795995

File tree

1 file changed

+23
-1
lines changed

1 file changed

+23
-1
lines changed

tf2onnx/tfonnx.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2266,6 +2266,28 @@ def rewrite_conv2d_with_pad(g, ops):
22662266
return ops
22672267

22682268

2269+
def rewrite_const_sub_with_pack(g, ops):
2270+
# slice op needs "begin" and "size" are const while tf fold_const can't fold const_sub with pack
2271+
pattern = \
2272+
OpTypePattern("Pack", name="pack", inputs=[
2273+
OpTypePattern("Sub", name="sub", inputs=[
2274+
OpTypePattern("Const"),
2275+
OpTypePattern("Const")
2276+
])
2277+
])
2278+
matcher = GraphMatcher(pattern)
2279+
match_results = list(matcher.match_ops(ops))
2280+
for match in match_results:
2281+
sub = match.get_op("sub")
2282+
sub_res = sub.inputs[0].get_tensor_value() - sub.inputs[1].get_tensor_value()
2283+
utils.make_sure(isinstance(sub_res, (int, float)), "pack input here should be a scalar")
2284+
pack = match.get_op("pack")
2285+
np_val = np.array([sub_res]).astype(utils.map_onnx_to_numpy_type(g.get_dtype(pack.output[0])))
2286+
const = g.make_const(utils.make_name("const_val"), np_val)
2287+
g.replace_all_inputs(ops, pack.output[0], const.output[0])
2288+
return g.get_nodes()
2289+
2290+
22692291
def tensorflow_onnx_mapping(g, continue_on_error, custom_op_handlers):
22702292
mapped_op = collections.Counter()
22712293
unmapped_op = collections.Counter()
@@ -2472,7 +2494,7 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
24722494

24732495
# pre-processing graph rewrites
24742496
# bi-directional re-writer should be placed after single directional re-writer
2475-
rewriters = [rewrite_transpose, rewrite_flatten,
2497+
rewriters = [rewrite_transpose, rewrite_flatten, rewrite_const_sub_with_pack,
24762498
rewrite_random_uniform, rewrite_random_uniform_fold_const,
24772499
rewrite_random_normal, rewrite_dropout,
24782500
rewrite_leakyrelu, rewrite_conv2d_with_pad,

0 commit comments

Comments
 (0)