@@ -2266,6 +2266,28 @@ def rewrite_conv2d_with_pad(g, ops):
     return ops
 
 
+def rewrite_const_sub_with_pack(g, ops):
+    # the Slice op needs "begin" and "size" to be const, but tf fold_const can't fold a const Sub feeding a Pack
+    pattern = \
+        OpTypePattern("Pack", name="pack", inputs=[
+            OpTypePattern("Sub", name="sub", inputs=[
+                OpTypePattern("Const"),
+                OpTypePattern("Const")
+            ])
+        ])
+    matcher = GraphMatcher(pattern)
+    match_results = list(matcher.match_ops(ops))
+    for match in match_results:
+        sub = match.get_op("sub")
+        sub_res = sub.inputs[0].get_tensor_value() - sub.inputs[1].get_tensor_value()
+        utils.make_sure(isinstance(sub_res, (int, float)), "pack input here should be a scalar")
+        pack = match.get_op("pack")
+        np_val = np.array([sub_res]).astype(utils.map_onnx_to_numpy_type(g.get_dtype(pack.output[0])))
+        const = g.make_const(utils.make_name("const_val"), np_val)
+        g.replace_all_inputs(ops, pack.output[0], const.output[0])
+    return g.get_nodes()
+
+
 def tensorflow_onnx_mapping(g, continue_on_error, custom_op_handlers):
     mapped_op = collections.Counter()
     unmapped_op = collections.Counter()
@@ -2472,7 +2494,7 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
 
     # pre-processing graph rewrites
     # bi-directional re-writer should be placed after single directional re-writer
-    rewriters = [rewrite_transpose, rewrite_flatten,
+    rewriters = [rewrite_transpose, rewrite_flatten, rewrite_const_sub_with_pack,
                  rewrite_random_uniform, rewrite_random_uniform_fold_const,
                  rewrite_random_normal, rewrite_dropout,
                  rewrite_leakyrelu, rewrite_conv2d_with_pad,
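Note (illustrative, not part of the commit): a minimal TF 1.x sketch of the kind of subgraph this rewriter targets; the tensor names and shapes below are made up for the example. Stacking the difference of two constants to build the "size" of a Slice produces Pack(Sub(Const, Const)), which TensorFlow's own constant folding leaves in place:

import tensorflow as tf  # assumes TF 1.x graph-mode API

x = tf.placeholder(tf.float32, shape=[10], name="input")
start = tf.constant(2)
end = tf.constant(7)
# "end - start" becomes a Sub op and tf.stack wraps it in a Pack op,
# so the "size" input of Slice is Pack(Sub(Const, Const)) rather than a Const.
size = tf.stack([end - start])
y = tf.slice(x, begin=[2], size=size)

After rewrite_const_sub_with_pack runs, every consumer of the Pack output is rewired to a new Const holding [5], so the Slice conversion can read "size" as a constant.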