Skip to content

Commit 0de86ee

Browse files
committed
enhance const_fold_optimizer: supports to fold const with Cast
1 parent 6ab3f7c commit 0de86ee

File tree

2 files changed

+27
-0
lines changed

2 files changed

+27
-0
lines changed

tests/test_optimizers.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,25 @@ def test_const_fold_unsqueeze_with_const(self):
503503
model_proto = helper.make_model(graph, producer_name="onnx-tests")
504504
self.run_and_compare(["res"], {"X": np.random.randn(1).astype(np.float32)}, model_proto,
505505
"Unsqueeze", 0)
506+
507+
def test_const_fold_cast_with_const(self):
508+
shape = (6, 6)
509+
const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
510+
vals=np.random.randn(*shape).flatten().astype(np.float32))
511+
node1 = helper.make_node("Constant", [], ["const"], value=const_tensor)
512+
node2 = helper.make_node("Cast", ["const"], ["value1"], to=TensorProto.INT64)
513+
node3 = helper.make_node("Add", ["value1", "X"], ["res"])
514+
515+
graph = helper.make_graph(
516+
[node1, node2, node3],
517+
"test_const_fold_cast_with_const",
518+
[helper.make_tensor_value_info("X", TensorProto.INT64, shape)],
519+
[helper.make_tensor_value_info("res", TensorProto.INT64, shape)],
520+
)
521+
522+
model_proto = helper.make_model(graph, producer_name="onnx-tests")
523+
self.run_and_compare(["res"], {"X": np.random.randn(*shape).astype(np.int64)}, model_proto,
524+
"Cast", 0)
506525
# Const Fold Optimizer Tests End
507526

508527

tf2onnx/optimizer/const_fold_optimizer.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,14 @@ def _replace_node_with_const(node, graph, vals):
9292
graph.replace_all_inputs(graph.get_nodes(), old_input, const_node.output[0])
9393
graph.remove_node(node.name)
9494

95+
@staticmethod
96+
@_register_func("Cast")
97+
def _fold_cast(node, graph):
98+
const_val = node.inputs[0].get_tensor_value(as_list=False)
99+
np_dtype = utils.ONNX_TO_NUMPY_DTYPE[node.get_attr("to").i]
100+
const_val_after_cast = const_val.astype(np_dtype)
101+
return [const_val_after_cast]
102+
95103
@staticmethod
96104
@_register_func("Transpose")
97105
def _fold_transpose(node, graph) -> list:

0 commit comments

Comments
 (0)