Skip to content

Commit 9c2b765

Browse files
authored
Merge pull request #501 from zhijxu-MS/tmp_branch_for_PR2
enhance fold_const_optimizer
2 parents a27312a + 9bbd3c4 commit 9c2b765

File tree

3 files changed

+71
-0
lines changed

3 files changed

+71
-0
lines changed

tests/test_backend.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1417,6 +1417,7 @@ def test_cancel_transpose(self):
14171417
_ = tf.identity(x_, name=_TFOUTPUT)
14181418
self._run_test_case([_OUTPUT], {_INPUT: x_val})
14191419

1420+
@check_onnxruntime_min_version("0.5.0", "topk-10's shape inference function has a bug")
14201421
@check_opset_min_version(6, "cast")
14211422
def test_topk1(self):
14221423
x_val = np.arange(3 * 2 * 3).astype("float32")

tests/test_optimizers.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,44 @@ def test_const_fold_node_is_output(self):
484484
model_proto = helper.make_model(graph, producer_name="onnx-tests")
485485
self.run_transpose_compare(["res"], {},
486486
model_proto, remaining_transpose_num=0)
487+
488+
def test_const_fold_unsqueeze_with_const(self):
    """Unsqueeze applied to a Constant should be folded away by the optimizer."""
    shape = (6, 6)
    data = np.random.randn(*shape).flatten().astype(np.float32)
    const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT,
                                      dims=shape, vals=data)
    const_node = helper.make_node("Constant", [], ["const"], value=const_tensor)
    unsqueeze_node = helper.make_node("Unsqueeze", ["const"], ["value1"], axes=[0, 2, 3])
    add_node = helper.make_node("Add", ["value1", "X"], ["res"])

    # (6, 6) unsqueezed at axes [0, 2, 3] yields (1, 6, 1, 1, 6); Add broadcasts X.
    graph = helper.make_graph(
        [const_node, unsqueeze_node, add_node],
        "test_const_fold_unsqueeze_with_const",
        [helper.make_tensor_value_info("X", TensorProto.FLOAT, (1,))],
        [helper.make_tensor_value_info("res", TensorProto.FLOAT, (1, 6, 1, 1, 6))],
    )

    model_proto = helper.make_model(graph, producer_name="onnx-tests")
    # expect zero Unsqueeze ops to remain after constant folding
    self.run_and_compare(["res"], {"X": np.random.randn(1).astype(np.float32)},
                         model_proto, "Unsqueeze", 0)
506+
507+
def test_const_fold_cast_with_const(self):
    """Cast applied to a Constant should be folded away by the optimizer."""
    shape = (6, 6)
    data = np.random.randn(*shape).flatten().astype(np.float32)
    const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT,
                                      dims=shape, vals=data)
    const_node = helper.make_node("Constant", [], ["const"], value=const_tensor)
    cast_node = helper.make_node("Cast", ["const"], ["value1"], to=TensorProto.INT64)
    add_node = helper.make_node("Add", ["value1", "X"], ["res"])

    graph = helper.make_graph(
        [const_node, cast_node, add_node],
        "test_const_fold_cast_with_const",
        [helper.make_tensor_value_info("X", TensorProto.INT64, shape)],
        [helper.make_tensor_value_info("res", TensorProto.INT64, shape)],
    )

    model_proto = helper.make_model(graph, producer_name="onnx-tests")
    # expect zero Cast ops to remain after constant folding
    self.run_and_compare(["res"], {"X": np.random.randn(*shape).astype(np.int64)},
                         model_proto, "Cast", 0)
487525
# Const Fold Optimizer Tests End
488526

489527

tf2onnx/optimizer/const_fold_optimizer.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,14 @@ def _replace_node_with_const(node, graph, vals):
9292
graph.replace_all_inputs(graph.get_nodes(), old_input, const_node.output[0])
9393
graph.remove_node(node.name)
9494

95+
@staticmethod
@_register_func("Cast")
def _fold_cast(node, graph):
    """Fold a Cast whose input is constant.

    Converts the constant tensor to the numpy dtype named by the node's
    "to" attribute and returns it as a one-element list, per the folding
    callback convention.
    """
    tensor = node.inputs[0].get_tensor_value(as_list=False)
    target_dtype = utils.ONNX_TO_NUMPY_DTYPE[node.get_attr("to").i]
    return [tensor.astype(target_dtype)]
102+
95103
@staticmethod
96104
@_register_func("Transpose")
97105
def _fold_transpose(node, graph) -> list:
@@ -100,3 +108,27 @@ def _fold_transpose(node, graph) -> list:
100108
perm = perm_attr.ints if perm_attr else None
101109
const_val_after_trans = const_val.transpose(perm)
102110
return [const_val_after_trans]
111+
112+
@staticmethod
@_register_func("Unsqueeze")
def _fold_unsqueeze(node, graph):
    """Fold an Unsqueeze whose input is constant.

    numpy's expand_dims only inserts one axis per call, so the output
    shape is computed directly from the "axes" attribute per the ONNX
    Unsqueeze spec and applied with a single reshape:
    https://github.com/onnx/onnx/blob/master/docs/Operators.md#Unsqueeze

    Returns a one-element list holding the unsqueezed numpy array.
    """
    const_val = node.inputs[0].get_tensor_value(as_list=False)
    axes = list(node.get_attr("axes").ints)
    shape_in = const_val.shape
    dims_out = len(shape_in) + len(axes)
    # ONNX allows negative axes (counted from dims_out) since opset 11;
    # normalize them so the shape computation below only sees [0, dims_out).
    axes = [axis + dims_out if axis < 0 else axis for axis in axes]
    utils.make_sure(all(0 <= axis < dims_out for axis in axes),
                    "Unsqueeze axes must be in the range [-dims_out, dims_out)")
    # every axis listed in "axes" becomes a new 1-sized dim; the remaining
    # output dims take the input dims in order
    shape_out = [None] * dims_out
    for axis in axes:
        shape_out[axis] = 1
    input_dims = iter(shape_in)
    for ind, val in enumerate(shape_out):
        if val is None:
            shape_out[ind] = next(input_dims)

    return [const_val.reshape(shape_out)]

0 commit comments

Comments
 (0)