
Commit d72b4d1

Authored by hwangdeyu, guschmue, and fatcat-z

Improve ZerosLike implementation and optimize for opset >= 9 (#2003)

* Improve ZerosLike implementation for opset >= 9
  Signed-off-by: Deyu Huang <[email protected]>
  Co-authored-by: Guenther Schmuelling <[email protected]>
* add a blank line
  Signed-off-by: Deyu Huang <[email protected]>
  Co-authored-by: Guenther Schmuelling <[email protected]>
  Co-authored-by: Jay Zhang <[email protected]>
1 parent 1c7d4ce commit d72b4d1

File tree: 4 files changed (+115, -1 lines)

tests/test_backend.py

Lines changed: 13 additions & 0 deletions

@@ -3887,6 +3887,19 @@ def func(x, y):
 
         self._run_test_case(func, [_OUTPUT], {_INPUT: input_x > 0.5, _INPUT1: input_y})
 
+    @check_opset_min_version(9, "ConstantOfShape")
+    def test_zeros_like_opset9(self):
+        input_x = np.random.random_sample([3, 16, 16]).astype(np.float32)
+        input_y = np.array([16, 16, 3]).astype(np.int64)
+
+        def func(x, y):
+            z = tf.reshape(x, y)
+            return tf.zeros_like(z, name=_TFOUTPUT)
+
+        self._run_test_case(func, [_OUTPUT], {_INPUT: input_x, _INPUT1: input_y})
+        self._run_test_case(func, [_OUTPUT], {_INPUT: input_x.astype(np.int32), _INPUT1: input_y}, as_session=True,
+                            graph_validator=lambda g: check_op_count(g, "ConstantOfShape", 1, disabled=False))
+
     @check_opset_min_version(9, "is_nan")
     def test_isnan(self):
         # only compatible with dtype `float32`
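For context on what this test exercises: the input is reshaped with a runtime-provided shape, so the converter cannot bake the zero tensor in as a static constant, and the second run's graph_validator asserts that exactly one ConstantOfShape node survives conversion. A rough numpy model of what the converted graph computes (the helper name below is illustrative, not part of the patch):

import numpy as np

def zeros_like_dynamic(x, new_shape):
    """Mimics tf.zeros_like(tf.reshape(x, new_shape)) the way the converted
    ONNX graph evaluates it: Reshape, then Shape, then ConstantOfShape(0)."""
    z = x.reshape(new_shape)                   # Reshape
    shape = np.array(z.shape, dtype=np.int64)  # Shape
    return np.zeros(shape, dtype=z.dtype)      # ConstantOfShape with value 0

x = np.random.random_sample([3, 16, 16]).astype(np.float32)
out = zeros_like_dynamic(x, np.array([16, 16, 3], dtype=np.int64))
assert out.shape == (16, 16, 3) and not out.any()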

tests/test_optimizers.py

Lines changed: 66 additions & 0 deletions

@@ -2241,6 +2241,72 @@ def test_const_fold_cast_with_const(self):
         self.run_and_compare(["res"], {"X": np.random.randn(*shape).astype(np.int64)}, model_proto,
                              "Cast", 0)
 
+    def test_const_fold_add(self):
+        shape = (6, 6)
+        const_tensor1 = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
+                                           vals=np.random.randn(*shape).flatten().astype(np.float32))
+        const_tensor2 = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
+                                           vals=np.random.randn(*shape).flatten().astype(np.float32))
+        node1 = helper.make_node("Constant", [], ["const1"], value=const_tensor1)
+        node2 = helper.make_node("Constant", [], ["const2"], value=const_tensor2)
+        node3 = helper.make_node("Add", ["const1", "const2"], ["add"])
+        node4 = helper.make_node("Add", ["add", "X"], ["res"])
+
+        graph = helper.make_graph(
+            [node1, node2, node3, node4],
+            "test_const_fold_add",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, shape)],
+            [helper.make_tensor_value_info("res", TensorProto.FLOAT, shape)],
+        )
+
+        model_proto = self.make_model(graph, producer_name="onnx-tests")
+        self.run_and_compare(["res"], {"X": np.random.randn(*shape).astype(np.float32)}, model_proto,
+                             "Add", 1)
+
+    def test_const_fold_sub(self):
+        shape = (6, 6)
+        const_tensor1 = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
+                                           vals=np.random.randn(*shape).flatten().astype(np.float32))
+        const_tensor2 = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
+                                           vals=np.random.randn(*shape).flatten().astype(np.float32))
+        node1 = helper.make_node("Constant", [], ["const1"], value=const_tensor1)
+        node2 = helper.make_node("Constant", [], ["const2"], value=const_tensor2)
+        node3 = helper.make_node("Sub", ["const1", "const2"], ["sub"])
+        node4 = helper.make_node("Sub", ["sub", "X"], ["res"])
+
+        graph = helper.make_graph(
+            [node1, node2, node3, node4],
+            "test_const_fold_sub",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, shape)],
+            [helper.make_tensor_value_info("res", TensorProto.FLOAT, shape)],
+        )
+
+        model_proto = self.make_model(graph, producer_name="onnx-tests")
+        self.run_and_compare(["res"], {"X": np.random.randn(*shape).astype(np.float32)}, model_proto,
+                             "Sub", 1)
+
+    def test_const_fold_mul(self):
+        shape = (6, 6)
+        const_tensor1 = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
+                                           vals=np.random.randn(*shape).flatten().astype(np.float32))
+        const_tensor2 = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
+                                           vals=np.random.randn(*shape).flatten().astype(np.float32))
+        node1 = helper.make_node("Constant", [], ["const1"], value=const_tensor1)
+        node2 = helper.make_node("Constant", [], ["const2"], value=const_tensor2)
+        node3 = helper.make_node("Mul", ["const1", "const2"], ["mul"])
+        node4 = helper.make_node("Mul", ["mul", "X"], ["res"])
+
+        graph = helper.make_graph(
+            [node1, node2, node3, node4],
+            "test_const_fold_mul",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, shape)],
+            [helper.make_tensor_value_info("res", TensorProto.FLOAT, shape)],
+        )
+
+        model_proto = self.make_model(graph, producer_name="onnx-tests")
+        self.run_and_compare(["res"], {"X": np.random.randn(*shape).astype(np.float32)}, model_proto,
+                             "Mul", 1)
+
     def test_const_fold_split(self):
         shape = (2, 6, 1)
         const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
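All three tests follow the same pattern: a const-const arithmetic node feeding a const-variable one, with the trailing arguments of run_and_compare asserting that exactly one Add/Sub/Mul remains after optimization. If the fold fires, each graph should collapse to the shape sketched below (a hypothetical post-optimization graph, with "folded" standing in for the precomputed const1 + const2 value):

import numpy as np
from onnx import TensorProto, helper

shape = (6, 6)
# Stand-in for the value the optimizer precomputes from const1 + const2.
folded_vals = np.random.randn(*shape).flatten().astype(np.float32)

folded = helper.make_tensor("folded_tensor", TensorProto.FLOAT, shape, folded_vals)
const_node = helper.make_node("Constant", [], ["folded"], value=folded)
add_node = helper.make_node("Add", ["folded", "X"], ["res"])  # the one surviving Add

graph = helper.make_graph(
    [const_node, add_node],
    "folded_add",
    [helper.make_tensor_value_info("X", TensorProto.FLOAT, shape)],
    [helper.make_tensor_value_info("res", TensorProto.FLOAT, shape)],
)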

tf2onnx/onnx_opset/generator.py

Lines changed: 12 additions & 1 deletion

@@ -8,7 +8,7 @@
 import logging
 
 import numpy as np
-from onnx import onnx_pb, numpy_helper
+from onnx import onnx_pb, numpy_helper, helper
 from tf2onnx import utils
 from tf2onnx.handler import tf_op
 from tf2onnx.graph_builder import GraphBuilder
@@ -242,6 +242,17 @@ def version_1(cls, ctx, node, **kwargs):
                       name=node.name, outputs=node.output,
                       shapes=shapes, dtypes=dtypes)
 
+    @classmethod
+    def version_9(cls, ctx, node, **kwargs):
+        dtypes = node.output_dtypes
+        ctx.remove_node(node.name)
+        shape = ctx.make_node("Shape", node.input).output[0]
+        zero_tensor = helper.make_tensor("value", dtypes[0], [1], vals=[0])
+        ctx.make_node("ConstantOfShape", inputs=[shape],
+                      attr={'value': zero_tensor},
+                      name=node.name, outputs=node.output,
+                      dtypes=dtypes)
+
 
 @tf_op(["IteratorV2", "FIFOQueueV2"])
 class Iterator:
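The handler removes the original ZerosLike node, takes Shape of its input, and feeds that into ConstantOfShape whose one-element value attribute is typed with the output dtype (dtypes[0]), so the zeros match the source tensor's type. The main win over a materialized constant is that the zero tensor is produced at runtime from the input's actual shape. A minimal standalone sketch of the resulting subgraph, built with the same onnx.helper API (graph and tensor names here are illustrative):

from onnx import TensorProto, helper

# Shape(x) -> ConstantOfShape(value=0) is the opset-9 lowering of ZerosLike.
shape_node = helper.make_node("Shape", ["x"], ["x_shape"])
zero = helper.make_tensor("value", TensorProto.FLOAT, [1], [0.0])
zeros_node = helper.make_node("ConstantOfShape", ["x_shape"], ["zeros"], value=zero)

graph = helper.make_graph(
    [shape_node, zeros_node],
    "zeros_like_opset9_sketch",
    [helper.make_tensor_value_info("x", TensorProto.FLOAT, [3, 16, 16])],
    [helper.make_tensor_value_info("zeros", TensorProto.FLOAT, [3, 16, 16])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 9)])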

tf2onnx/optimizer/const_fold_optimizer.py

Lines changed: 24 additions & 0 deletions

@@ -162,6 +162,30 @@ def _fold_unsqueeze(node, graph):
         const_val_after_unsqueeze = const_val.reshape(shape_out)
         return [const_val_after_unsqueeze]
 
+    @staticmethod
+    @_register_func("Mul")
+    def _fold_mul(node, graph):
+        const_val1 = node.inputs[0].get_tensor_value(as_list=False)
+        const_val2 = node.inputs[1].get_tensor_value(as_list=False)
+        const_val_after_mul = np.multiply(const_val1, const_val2)
+        return [const_val_after_mul]
+
+    @staticmethod
+    @_register_func("Add")
+    def _fold_add(node, graph):
+        const_val1 = node.inputs[0].get_tensor_value(as_list=False)
+        const_val2 = node.inputs[1].get_tensor_value(as_list=False)
+        const_val_after_add = np.add(const_val1, const_val2)
+        return [const_val_after_add]
+
+    @staticmethod
+    @_register_func("Sub")
+    def _fold_sub(node, graph):
+        const_val1 = node.inputs[0].get_tensor_value(as_list=False)
+        const_val2 = node.inputs[1].get_tensor_value(as_list=False)
+        const_val_after_sub = np.subtract(const_val1, const_val2)
+        return [const_val_after_sub]
+
     @staticmethod
     @_register_func("Split")
     def _fold_split(node, graph):
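Each new handler follows the optimizer's existing registration pattern: _register_func keys the function by ONNX op type, the handler pulls both constant inputs as numpy arrays, and the returned list of folded values replaces the node with a precomputed constant. A self-contained mock of that dispatch pattern (a sketch, not the actual tf2onnx internals):

import numpy as np

_func_map = {}

def _register_func(op_type):
    """Register a fold handler under its ONNX op type."""
    def decorator(func):
        _func_map[op_type] = func
        return func
    return decorator

@_register_func("Add")
def _fold_add(val1, val2):
    # Both inputs are constants, so the result can be computed offline.
    return [np.add(val1, val2)]

# Dispatch by op type, as the optimizer does when it finds a foldable node.
folded = _func_map["Add"](np.ones((2, 2), np.float32),
                          np.full((2, 2), 3, np.float32))
assert np.array_equal(folded[0], np.full((2, 2), 4, np.float32))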
