
Commit 0fb855d

Merge remote-tracking branch 'upstream/master' into jignparm/optimize_transpose_multiply

2 parents 37be39d + f482810

9 files changed: +102 -29 lines changed

tests/test_backend.py

Lines changed: 23 additions & 0 deletions
@@ -1787,6 +1787,29 @@ def func(x, y):
             return tf.identity(x_, name=_TFOUTPUT)
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
 
+    @check_opset_min_version(10, "Slice")
+    def test_strided_slice_reverse_1(self):
+        x_val = np.arange(16 * 32).astype(np.float32).reshape((1, 16, 32, 1))
+        def func(x):
+            return tf.concat([x[:, :, :10], x[:, :, :21:-1]], axis=0, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
+
+    @check_opset_min_version(10, "Slice")
+    def test_strided_slice_reverse_2(self):
+        x_val = np.arange(16 * 32).astype(np.float32).reshape((1, 16, 32, 1))
+        def func(x):
+            return tf.concat([x[:, :, :10], x[:, :, 9::-1]], axis=0, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
+
+    @check_opset_min_version(10, "Slice")
+    def test_strided_slice_reverse_3(self):
+        x_val = np.zeros((1, 16, 32, 1)).astype(np.float32)
+        y_val = np.array(9).astype(np.int32)
+        z_val = np.array(-1).astype(np.int32)
+        def func(x, y, z):
+            return tf.concat([x[:, :, :10], x[:, :, y::z]], axis=0, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val, _INPUT2: z_val})
+
     @check_opset_min_version(10, "Slice")
     def test_new_axis_mask(self):
         def func(x, y):
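For readers unfamiliar with the pattern under test: a negative stride reverses the sliced axis, and the begin/end defaults flip accordingly. A minimal numpy illustration of that indexing (illustrative only, not part of the test suite):

import numpy as np

x = np.arange(8)
print(x[:4:-1])   # [7 6 5]   -> begin defaults to the end when the stride is negative
print(x[3::-1])   # [3 2 1 0] -> end defaults to (before) the start of the axis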

tf2onnx/constants.py

Lines changed: 6 additions & 0 deletions
@@ -37,3 +37,9 @@
 
 # Environment variables
 ENV_TF2ONNX_DEBUG_MODE = "TF2ONNX_DEBUG_MODE"
+
+# Mapping opset to IR version.
+# When adding entries here, make sure the IR changes don't impact what we do.
+OPSET_TO_IR_VERSION = {
+    1: 3, 2: 3, 3: 3, 4: 4, 5: 3, 6: 3, 7: 3, 8: 4, 9: 4, 10: 5, 11: 6, 12: 7
+}
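A quick sketch of how the table is meant to be consumed (mirroring the graph.py change below; the fallback default here is illustrative):

from tf2onnx import constants

# Unknown opsets fall back to whatever default the caller supplies.
ir_version = constants.OPSET_TO_IR_VERSION.get(10, 4)
assert ir_version == 5  # opset 10 models should declare IR version 5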

tf2onnx/custom_opsets/ms.py

Lines changed: 11 additions & 0 deletions
@@ -92,3 +92,14 @@ def version_1(cls, ctx, node, **kwargs):
            node.attr.pop("padding")
        if "explicit_paddings" in node.attr:
            node.attr.pop("explicit_paddings")
+
+@tf_op("CropAndResize", domain=constants.MICROSOFT_DOMAIN)
+class CropAndResize:
+    @classmethod
+    def version_11(cls, ctx, node, **kwargs):
+        """ utilize contrib cropandresize """
+        node.attr['method'].name = 'mode'
+        node.domain = constants.MICROSOFT_DOMAIN
+        ctx.insert_new_node_on_input(node, "Transpose", node.input[0], perm=constants.NHWC_TO_NCHW)
+        ctx.insert_new_node_on_output("Transpose", node.output[0], node.name + '_transposed',
+                                      None, perm=constants.NCHW_TO_NHWC)
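The inserted transposes bridge TensorFlow's NHWC layout and the NCHW layout the contrib kernel expects. A numpy round-trip check of the two permutations (the literal values are assumed to match constants.NHWC_TO_NCHW and constants.NCHW_TO_NHWC):

import numpy as np

NHWC_TO_NCHW = [0, 3, 1, 2]
NCHW_TO_NHWC = [0, 2, 3, 1]

x = np.zeros((1, 16, 32, 3))             # NHWC input
nchw = np.transpose(x, NHWC_TO_NCHW)     # (1, 3, 16, 32) for the contrib op
back = np.transpose(nchw, NCHW_TO_NHWC)  # restores (1, 16, 32, 3)
assert back.shape == x.shape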

tf2onnx/graph.py

Lines changed: 6 additions & 0 deletions
@@ -1039,6 +1039,12 @@ def make_model(self, graph_doc, optimize=False, graph_name="tf2onnx", **kwargs):
         kwargs["opset_imports"] = opsets
         model_proto = helper.make_model(graph, **kwargs)
 
+        # set the IR version based on opset
+        try:
+            model_proto.ir_version = constants.OPSET_TO_IR_VERSION.get(self.opset, model_proto.ir_version)
+        except:
+            logger.error("ir_version override failed - install the latest onnx version")
+
         # optimize the model proto.
         # TODO: this is disabled by default because of bugs in fuse_consecutive_transposes
         if optimize:
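In effect the override pins the ModelProto's ir_version to the opset it was built for, rather than the newest IR version the installed onnx package writes by default. A minimal standalone sketch (empty graph, purely illustrative):

from onnx import helper

model = helper.make_model(helper.make_graph([], "g", [], []),
                          opset_imports=[helper.make_opsetid("", 10)])
model.ir_version = 5  # pair opset 10 with IR version 5, per OPSET_TO_IR_VERSION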

tf2onnx/onnx_opset/controlflow.py

Lines changed: 2 additions & 2 deletions
@@ -372,14 +372,14 @@ def version_9(cls, ctx, node, **kwargs):
         ctx.copy_dtype(node.output[0], transpose_node.output[0])
 
 
-@tf_op("IteratorV2")
+@tf_op("IteratorV2", "FIFOQueueV2")
 class Iterator:
     @classmethod
     def version_8(cls, ctx, node, **kwargs):
         ctx.remove_node(node.name)
 
 
-@tf_op("IteratorGetNext")
+@tf_op("IteratorGetNext", "QueueDequeueV2")
 class IteratorGetNext:
     @classmethod
     def version_8(cls, ctx, node, **kwargs):
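This change works because @tf_op accepts several TensorFlow op names and registers the same handler for each, so FIFOQueueV2/QueueDequeueV2 are now converted exactly like dataset iterators. A simplified sketch of that dispatch idea (not the actual registry code):

_HANDLERS = {}

def tf_op(*names, domain=""):
    def deco(cls):
        for name in names:            # e.g. "IteratorGetNext", "QueueDequeueV2"
            _HANDLERS[name] = cls     # both names resolve to one handler class
        return cls
    return deco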

tf2onnx/onnx_opset/math.py

Lines changed: 3 additions & 2 deletions
@@ -97,8 +97,8 @@ def make_min_or_max_op(ctx, op_type, inputs, outputs,
         ctx.set_dtype(cast_node.output[0], origin_dtype)
         ctx.copy_shape(node.output[0], cast_node.output[0])
         actual_outputs = cast_node.output
-    ctx.make_node("Identity", actual_outputs, outputs=outputs,
-                  shapes=output_shapes, dtypes=output_dtypes)
+    final_node = ctx.make_node("Identity", actual_outputs, outputs=outputs,
+                               shapes=output_shapes, dtypes=output_dtypes)
 
     # tensorflow minimum/maximum does support broadcast, onnx < opset 8 does not.
     # handle this by doing something like:
@@ -124,6 +124,7 @@ def make_min_or_max_op(ctx, op_type, inputs, outputs,
             add_node = ctx.make_node("Add", [input_node.output[0], sub_node.output[0]],
                                      op_name_scope=input_node.name)
             node.input[i] = add_node.output[0]
+    return final_node
 
 
 @tf_op("Minimum", onnx_op="Min")

tf2onnx/onnx_opset/tensor.py

Lines changed: 38 additions & 24 deletions
@@ -845,37 +845,51 @@ def version_10(cls, ctx, node, **kwargs):
         # mask begin
         new_begin_mask = np.array(new_begin_mask, dtype=np_dtype)
         if not np.all(new_begin_mask == 1):
-            if begin.is_const():
-                begin = ctx.make_const(
-                    utils.make_name("begin_masked"),
-                    begin.get_tensor_value(as_list=False) * new_begin_mask
-                )
+            if begin.is_const() and strides.is_const():
+                new_begin_vals = np.copy(begin.get_tensor_value(as_list=False))
+                strides_vals = strides.get_tensor_value(as_list=False)
+                idx1 = np.where(new_begin_mask == 0)
+                idx2 = np.where(strides_vals < 0)
+                idx3 = np.intersect1d(idx1, idx2)
+                new_begin_vals[idx3] = max_size
+                begin = ctx.make_const(utils.make_name("begin_masked"), new_begin_vals)
             else:
-                begin_mask_const = ctx.make_const(
-                    utils.make_name("begin_mask"),
-                    new_begin_mask
-                )
-                begin = ctx.make_node(
-                    "Mul", [begin.output[0], begin_mask_const.output[0]],
-                    op_name_scope=node.name
-                )
+                begin_mask_const = ctx.make_const(utils.make_name("begin_mask"), np.equal(new_begin_mask, 0))
+                zero_const = ctx.make_const(utils.make_name("zero_const"), np.zeros(1, dtype=np_dtype))
+                max_const = ctx.make_const(utils.make_name("max_const"), np.array(max_size, dtype=np_dtype))
+                op1 = ctx.make_node("Less", [strides.output[0], zero_const.output[0]], op_name_scope=node.name)
+                op2 = ctx.make_node("And", [op1.output[0], begin_mask_const.output[0]], op_name_scope=node.name)
+                begin = ctx.make_node("Where", [op2.output[0], max_const.output[0], begin.output[0]],
+                                      op_name_scope=node.name)
+
         # mask end
         new_end_mask = np.array(new_end_mask, dtype=np_dtype)
         end_output = end.output[0]
         if not np.all(new_end_mask == min_size):
-            if end.is_const():
-                end = ctx.make_const(
-                    utils.make_name("end_masked"),
-                    np.maximum(end.get_tensor_value(as_list=False), new_end_mask)
-                )
+            if end.is_const() and strides.is_const():
+                new_end_mask = np.maximum(end.get_tensor_value(as_list=False), new_end_mask)
+                idx = np.where(new_end_mask == max_size)
+                sign = np.sign(strides.get_tensor_value(as_list=False))[idx]
+                new_end_mask[idx] = new_end_mask[idx] * sign
+                end = ctx.make_const(utils.make_name("end_masked"), new_end_mask)
                 end_output = end.output[0]
             else:
-                end_mask_const = ctx.make_const(
-                    utils.make_name("end_mask"),
-                    np.array(new_end_mask, dtype=np_dtype)
-                )
-                end_output = utils.make_name("{}__end".format(node.name))
-                math.make_min_or_max_op(ctx, "Max", [end.output[0], end_mask_const.output[0]], [end_output])
+                # Overlay new_end_mask with specified end values.
+                # Adjust max_size to min_size if steps are < 0
+                max_const = ctx.make_const(utils.make_name("max_const"), np.array(max_size, dtype=np_dtype))
+                min_const = ctx.make_const(utils.make_name("min_const"), np.array(min_size, dtype=np_dtype))
+                zero_const = ctx.make_const(utils.make_name("zero_const"), np.zeros(1, dtype=np_dtype))
+                end_mask_const = ctx.make_const(utils.make_name("end_mask"), np.array(new_end_mask, dtype=np_dtype))
+                outputname = utils.make_name("{}__newendmask".format(node.name))
+                new_end_mask = math.make_min_or_max_op(ctx, "Max", [end.output[0], end_mask_const.output[0]],
+                                                       [outputname])
+                op1 = ctx.make_node("Less", [strides.output[0], zero_const.output[0]], op_name_scope=node.name)
+                op2 = ctx.make_node("Equal", [new_end_mask.output[0], max_const.output[0]], op_name_scope=node.name)
+                op3 = ctx.make_node("And", [op2.output[0], op1.output[0]], op_name_scope=node.name)
+                final_end = ctx.make_node("Where", [op3.output[0], min_const.output[0],
+                                                    new_end_mask.output[0]], op_name_scope=node.name)
+                end_output = final_end.output[0]
+
         # mask strides for shrink
         shrink_strided_mask = np.array(shrink_strided_mask, dtype=np_dtype)
         strides_output = strides.output[0]
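The crux of both branches: when a begin/end value is left to its masked default but the stride is negative, the default must point at the far end of the axis, otherwise the reversed slice comes out empty. A numpy sketch of the constant-folding path (max_size stands in for the int-max sentinel used by the converter):

import numpy as np

max_size = np.iinfo(np.int64).max
new_begin_mask = np.array([1, 0, 1])                 # 0 = begin comes from the mask default
strides_vals = np.array([1, -1, 1])
begin_vals = np.array([0, 0, 2], dtype=np.int64)

masked = np.where(new_begin_mask == 0)
negative = np.where(strides_vals < 0)
flip = np.intersect1d(masked, negative)              # masked axes sliced in reverse
begin_vals[flip] = max_size                          # start those slices from the end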

tf2onnx/rewriter/gemm_rewriter.py

Lines changed: 12 additions & 0 deletions
@@ -99,4 +99,16 @@ def get_gemm_attr(match):
                return attr, False
            match_args = match_args[0]
            attr[arg] = match_args
+    for arg in ["matmul"]:
+        arg_op = match.get_op(arg)
+        if arg_op is not None:
+            match_args = arg_op.attr
+            if isinstance(match_args, dict):
+                keys = list(match_args.keys())
+                if 'transpose_a' not in keys and 'transpose_b' not in keys:
+                    return attr, False
+                match_args_a = match_args['transpose_a'].i
+                attr['transA'] = match_args_a
+                match_args_b = match_args['transpose_b'].i
+                attr['transB'] = match_args_b
     return attr, True
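For reference, the flags copied here preserve MatMul semantics on the fused Gemm: TensorFlow's transpose_a/transpose_b become Gemm's transA/transB. A numpy check of that equivalence (illustrative shapes):

import numpy as np

a = np.random.rand(3, 2).astype(np.float32)  # will be transposed: transA=1
b = np.random.rand(3, 4).astype(np.float32)
c = np.zeros(4, dtype=np.float32)

# tf.matmul(a, b, transpose_a=True) + c  ~  Gemm(A, B, C, transA=1, transB=0)
gemm_out = np.dot(a.T, b) + c
assert gemm_out.shape == (2, 4)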

tf2onnx/tf_utils.py

Lines changed: 1 addition & 1 deletion
@@ -138,7 +138,7 @@ def tflist_to_onnx(g, shape_override):
                  "Tout", "Tlabels", "Tindex", "element_shape", "Targmax", "Tperm", "Tcond",
                  "T_threshold", "element_dtype", "shape_type", "_lower_using_switch_merge",
                  "parallel_iterations", "_num_original_outputs", "output_types", "output_shapes",
-                 "key_dtype", "value_dtype", "Tin", "Tout"]
+                 "key_dtype", "value_dtype", "Tin", "Tout", "capacity", "component_types", "shapes"]
 
     node_list = g.get_operations()
     functions = {}
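For context: capacity, component_types and shapes are attributes carried by the FIFOQueueV2/QueueDequeueV2 ops handled above, and ONNX has no counterpart for them, so they are presumably added to the set of TensorFlow attributes dropped during graph import, complementing the queue handlers registered in controlflow.py.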
