Skip to content

Commit c2eef98

Browse files
Merge pull request #890 from RandySheriffH/rashuai/MatrixDiagPartV3
Support MatrixDiagPart v2 and v3
2 parents aa0615e + 9cfa737 commit c2eef98

File tree

2 files changed

+236
-0
lines changed

2 files changed

+236
-0
lines changed

tests/test_backend.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3108,6 +3108,21 @@ def func(x, y):
31083108
return tf.cast(s_, tf.float32, name=_TFOUTPUT)
31093109
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
31103110

3111+
@check_opset_min_version(11)
3112+
@check_tf_min_version("2.2")
3113+
def test_matrix_diag_part_v3(self):
3114+
3115+
def func(X, K):
3116+
v2 = tf.raw_ops.MatrixDiagPartV2(input=X, k=K, padding_value=0.123, name=_TFOUTPUT)
3117+
v3 = tf.raw_ops.MatrixDiagPartV3(input=X, k=K, padding_value=0.123, align='RIGHT_LEFT', name=_TFOUTPUT1)
3118+
return v2, v3
3119+
3120+
for x_shape in ([4, 5], [2, 3, 4, 5]):
3121+
x_val = np.random.random(x_shape).astype(np.float32)
3122+
for raw_k in ([0], [1], [3], [-1], [-3], [1, 2], [-2, -1], [-1, 1]):
3123+
k_val = np.array(raw_k).astype(np.int32)
3124+
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: x_val, _INPUT1: k_val})
3125+
31113126

31123127
if __name__ == '__main__':
31133128
unittest_main()

tf2onnx/onnx_opset/tensor.py

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1819,6 +1819,227 @@ def version_11(cls, ctx, node, **kwargs):
18191819
name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
18201820

18211821

1822+
@tf_op(["MatrixDiagPartV2", "MatrixDiagPartV3"])
1823+
class MatrixDiagPartV2V3:
1824+
@classmethod
1825+
def version_11(cls, ctx, node, **kwargs):
1826+
# assemble MatrixDiagPart V2&V3 by looping k diagonals with proper pads
1827+
input_tensor = node.input[0]
1828+
k = ctx.make_node('Cast', [node.input[1]], attr={'to': TensorProto.INT64}).output[0]
1829+
padding = node.input[2]
1830+
align = 'LEFT_LEFT'
1831+
if node.op.op_type == 'MatrixDiagPartV3':
1832+
align = node.get_attr_str('align') if 'align' in node.attr else 'LEFT_RIGHT'
1833+
input_rank = len(ctx.get_shape(input_tensor))
1834+
raw_input_shape = [-1] * input_rank
1835+
per_loop_shape = raw_input_shape[:-1]
1836+
raw_output_shape = raw_input_shape[:-2] + [-1]
1837+
loop_output_shape = raw_output_shape + [-1]
1838+
ctx.set_shape(node.output[0], raw_output_shape)
1839+
for out in ctx.find_output_consumers(node.output[0]):
1840+
if out.op.op_type == 'Identity':
1841+
ctx.set_shape(out.output[0], raw_output_shape)
1842+
# define constants
1843+
const_zero = ctx.make_const(utils.make_name(node.name) + 'const_zero', np.array([0]).astype(np.int64))
1844+
const_one = ctx.make_const(utils.make_name(node.name) + 'const_one', np.array([1]).astype(np.int64))
1845+
const_two = ctx.make_const(utils.make_name(node.name) + 'const_two', np.array([2]).astype(np.int64))
1846+
const_neg_one = ctx.make_const(utils.make_name(node.name) + 'const_neg_one', np.array([-1]).astype(np.int64))
1847+
const_neg_two = ctx.make_const(utils.make_name(node.name) + 'const_neg_two', np.array([-2]).astype(np.int64))
1848+
# prepare new_shape of input
1849+
input_shape = ctx.make_node('Shape', [input_tensor])
1850+
shape_input_shape = ctx.make_node('Shape', [input_shape.output[0]])
1851+
matrix_shape = ctx.make_node('Slice',
1852+
[input_shape.output[0], const_neg_two.output[0], shape_input_shape.output[0]])
1853+
min_dim = ctx.make_node('ReduceMin', [matrix_shape.output[0]])
1854+
input_depth = ctx.make_node('Slice', [matrix_shape.output[0], const_neg_two.output[0], const_neg_one.output[0]])
1855+
input_width = ctx.make_node('Slice', [matrix_shape.output[0], const_neg_one.output[0], const_two.output[0]])
1856+
temp_shape = ctx.make_node('Concat', [const_neg_one.output[0], matrix_shape.output[0]], attr={'axis': 0})
1857+
temp_input = ctx.make_node('Reshape', [input_tensor, temp_shape.output[0]])
1858+
temp_transposed = ctx.make_node('Transpose', [temp_input.output[0]], attr={'perm': [0, 2, 1]})
1859+
half_shape = ctx.make_node('Slice', [input_shape.output[0], const_zero.output[0], const_neg_two.output[0]])
1860+
new_shape = ctx.make_node('Concat', [half_shape.output[0], input_width.output[0], input_depth.output[0]],
1861+
attr={'axis': 0})
1862+
# define body graph for main loop
1863+
k_shape = ctx.make_node('Shape', [k])
1864+
k_start = ctx.make_node('Slice', [k, const_zero.output[0], const_one.output[0]])
1865+
k_end = ctx.make_node('Slice', [k, const_neg_one.output[0], k_shape.output[0]])
1866+
raw_total_k = ctx.make_node('Sub', [k_end.output[0], k_start.output[0]])
1867+
total_k = ctx.make_node('Add', [raw_total_k.output[0], const_one.output[0]])
1868+
trip_name = utils.make_name(node.name + "_i")
1869+
cond_name = utils.make_name(node.name + "_cond")
1870+
body_graph = ctx.create_new_graph_with_same_config()
1871+
body_graph.add_graph_input(trip_name, TensorProto.INT64, [1])
1872+
body_graph.add_graph_input(cond_name, TensorProto.BOOL, [])
1873+
body_graph.parent_graph = ctx
1874+
# identity of input
1875+
identity_input_graph = body_graph.create_new_graph_with_same_config()
1876+
identity_input_graph.parent_graph = body_graph
1877+
identity_input = identity_input_graph.make_node('Identity', [input_tensor])
1878+
identity_input_graph.add_graph_output(identity_input.output[0], ctx.get_dtype(node.input[0]), raw_input_shape)
1879+
# transposed input
1880+
transposed_input_graph = body_graph.create_new_graph_with_same_config()
1881+
transposed_input_graph.parent_graph = body_graph
1882+
next_shape = transposed_input_graph.make_node('Concat', [half_shape.output[0], input_width.output[0],
1883+
input_depth.output[0]], attr={'axis': 0})
1884+
transposed_input = transposed_input_graph.make_node('Reshape',
1885+
[temp_transposed.output[0], next_shape.output[0]])
1886+
transposed_input_graph.add_graph_output(transposed_input.output[0], ctx.get_dtype(node.input[0]),
1887+
raw_input_shape)
1888+
# compute current k of the loop
1889+
current_k = body_graph.make_node('Sub', [k_end.output[0], trip_name])
1890+
is_k_noneg = body_graph.make_node('Greater', [current_k.output[0], const_neg_one.output[0]])
1891+
processed_input = body_graph.make_node('If', [is_k_noneg.output[0]])
1892+
processed_input.set_body_graph_as_attr('then_branch', identity_input_graph)
1893+
processed_input.set_body_graph_as_attr('else_branch', transposed_input_graph)
1894+
processed_shape = body_graph.make_node('Shape', [processed_input.output[0]])
1895+
shape_processed_shape = body_graph.make_node('Shape', [processed_shape.output[0]])
1896+
new_depth = body_graph.make_node('Slice',
1897+
[processed_shape.output[0], const_neg_two.output[0], const_neg_one.output[0]])
1898+
new_width = body_graph.make_node('Slice', [processed_shape.output[0], const_neg_one.output[0],
1899+
shape_processed_shape.output[0]])
1900+
abs_k = body_graph.make_node('Abs', [current_k.output[0]])
1901+
range_k = body_graph.make_node('Range', [abs_k.output[0], new_width.output[0], const_one.output[0]],
1902+
domain="com.microsoft")
1903+
sliced_range = body_graph.make_node('Slice', [range_k.output[0], const_zero.output[0], new_depth.output[0]])
1904+
sliced_shape = body_graph.make_node('Shape', [sliced_range.output[0]])
1905+
pad_length = body_graph.make_node('Sub', [new_depth.output[0], sliced_shape.output[0]])
1906+
pad_length_2 = body_graph.make_node('Concat', [const_zero.output[0], pad_length.output[0]], attr={'axis': 0})
1907+
padded_range = body_graph.make_node('Pad', [sliced_range.output[0], pad_length_2.output[0]])
1908+
unsqueezed_range = body_graph.make_node('Unsqueeze', [padded_range.output[0]], attr={'axes': [1]})
1909+
half_shape_x = body_graph.make_node('Slice',
1910+
[new_shape.output[0], const_zero.output[0], const_neg_two.output[0]])
1911+
shape_range = body_graph.make_node('Shape', [unsqueezed_range.output[0]])
1912+
full_shape = body_graph.make_node('Concat', [half_shape_x.output[0], shape_range.output[0]], attr={'axis': 0})
1913+
expanded_range = body_graph.make_node('Expand', [unsqueezed_range.output[0], full_shape.output[0]])
1914+
gathered_input = body_graph.make_node('GatherElements', [processed_input.output[0], expanded_range.output[0]],
1915+
attr={'axis': -1})
1916+
squeezed_input = body_graph.make_node('Squeeze', [gathered_input.output[0]], attr={'axes': [-1]})
1917+
left_width = body_graph.make_node('Sub', [new_width.output[0], abs_k.output[0]])
1918+
dims = body_graph.make_node('Concat', [left_width.output[0], new_depth.output[0]], attr={'axis': 0})
1919+
valid_dim = body_graph.make_node('ReduceMin', [dims.output[0]])
1920+
raw_output = body_graph.make_node('Slice', [squeezed_input.output[0], const_zero.output[0], valid_dim.output[0],
1921+
const_neg_one.output[0]])
1922+
gap_output = body_graph.make_node('Sub', [min_dim.output[0], valid_dim.output[0]])
1923+
gaps = body_graph.make_node('Concat', [const_zero.output[0], gap_output.output[0]], attr={'axis': 0})
1924+
processed_gap = body_graph.make_node('ReduceMax', [gaps.output[0]])
1925+
pad_zero = body_graph.make_node('Mul', [new_shape.output[0], const_zero.output[0]])
1926+
sliced_zero = body_graph.make_node('Slice', [pad_zero.output[0], const_zero.output[0], const_neg_two.output[0]])
1927+
# gap_pos_k_graph
1928+
gap_pos_k_graph = body_graph.create_new_graph_with_same_config()
1929+
gap_pos_k_graph.parent_graph = body_graph
1930+
gap_pos_k = gap_pos_k_graph.make_node('Concat', [const_zero.output[0],
1931+
processed_gap.output[0]],
1932+
attr={'axis': 0}) \
1933+
if align.startswith('LEFT') \
1934+
else gap_pos_k_graph.make_node('Concat', [processed_gap.output[0],
1935+
const_zero.output[0]],
1936+
attr={'axis': 0})
1937+
gap_pos_k_graph.add_graph_output(gap_pos_k.output[0], TensorProto.INT64, [-1])
1938+
# gap_neg_k_graph
1939+
gap_neg_k_graph = body_graph.create_new_graph_with_same_config()
1940+
gap_neg_k_graph.parent_graph = body_graph
1941+
gap_neg_k = gap_neg_k_graph.make_node('Concat', [const_zero.output[0],
1942+
processed_gap.output[0]],
1943+
attr={'axis': 0}) \
1944+
if align.endswith('LEFT') \
1945+
else gap_neg_k_graph.make_node('Concat', [processed_gap.output[0],
1946+
const_zero.output[0]],
1947+
attr={'axis': 0})
1948+
gap_neg_k_graph.add_graph_output(gap_neg_k.output[0], TensorProto.INT64, [-1])
1949+
# pad output with gap
1950+
gap_k = body_graph.make_node('If', [is_k_noneg.output[0]])
1951+
gap_k.set_body_graph_as_attr("then_branch", gap_pos_k_graph)
1952+
gap_k.set_body_graph_as_attr("else_branch", gap_neg_k_graph)
1953+
gap_left = body_graph.make_node('Slice', [gap_k.output[0], const_zero.output[0], const_one.output[0]])
1954+
gap_right = body_graph.make_node('Slice', [gap_k.output[0], const_one.output[0], const_two.output[0]])
1955+
gap_all = body_graph.make_node('Concat', [sliced_zero.output[0], gap_left.output[0], sliced_zero.output[0],
1956+
gap_right.output[0]], attr={'axis': 0})
1957+
padded_output = body_graph.make_node('Pad', [raw_output.output[0], gap_all.output[0], padding])
1958+
cond_output = body_graph.make_node('Identity', [cond_name])
1959+
body_graph.add_graph_output(cond_output.output[0], TensorProto.BOOL, [])
1960+
body_graph.add_graph_output(padded_output.output[0], ctx.get_dtype(node.input[0]), per_loop_shape)
1961+
body_graph.add_graph_output(gap_k.output[0], TensorProto.INT64, [-1])
1962+
# make loop
1963+
cond_const = ctx.make_const(utils.make_name("cond"), np.ones((), dtype=np.bool))
1964+
main_loop = ctx.make_node('Loop', [total_k.output[0], cond_const.output[0]], output_count=2)
1965+
main_loop.set_body_graph_as_attr("body", body_graph)
1966+
# reshape output
1967+
next_padded_shape = ctx.make_node('Concat', [total_k.output[0], const_neg_one.output[0], min_dim.output[0]],
1968+
attr={'axis': 0})
1969+
reshaped_padded = ctx.make_node('Reshape', [main_loop.output[0], next_padded_shape.output[0]])
1970+
transposed_padded = ctx.make_node('Transpose', [reshaped_padded.output[0]], attr={'perm': [1, 0, 2]})
1971+
output_shape = ctx.make_node('Concat', [half_shape.output[0], total_k.output[0], const_neg_one.output[0]],
1972+
attr={'axis': 0})
1973+
reshaped_output = ctx.make_node('Reshape', [transposed_padded.output[0], output_shape.output[0]])
1974+
# compute pads
1975+
left_pads = ctx.make_node('Slice', [main_loop.output[1], const_neg_two.output[0], const_neg_one.output[0],
1976+
const_neg_one.output[0]])
1977+
flattened_left_pads = ctx.make_node('Reshape', [left_pads.output[0], const_neg_one.output[0]])
1978+
min_left_pads = ctx.make_node('ReduceMin', [flattened_left_pads.output[0]])
1979+
right_pads = ctx.make_node('Slice', [main_loop.output[1], const_neg_one.output[0], const_two.output[0],
1980+
const_neg_one.output[0]])
1981+
flattened_right_pads = ctx.make_node('Reshape', [right_pads.output[0], const_neg_one.output[0]])
1982+
min_right_pads = ctx.make_node('ReduceMin', [flattened_right_pads.output[0]])
1983+
# trim left pads
1984+
identity_left_sliced_graph = ctx.create_new_graph_with_same_config()
1985+
identity_left_sliced_graph.parent_graph = ctx
1986+
identity_left_sliced = identity_left_sliced_graph.make_node('Identity', [reshaped_output.output[0]])
1987+
identity_left_sliced_graph.add_graph_output(identity_left_sliced.output[0], ctx.get_dtype(node.input[0]),
1988+
loop_output_shape)
1989+
output_left_sliced_graph = ctx.create_new_graph_with_same_config()
1990+
output_left_sliced_graph.parent_graph = ctx
1991+
output_left_sliced = output_left_sliced_graph.make_node('Slice',
1992+
[reshaped_output.output[0], min_left_pads.output[0],
1993+
min_dim.output[0], const_neg_one.output[0]])
1994+
output_left_sliced_graph.add_graph_output(output_left_sliced.output[0], ctx.get_dtype(node.input[0]),
1995+
loop_output_shape)
1996+
left_pads_greater_than_zero = ctx.make_node('Greater', [min_left_pads.output[0], const_zero.output[0]])
1997+
final_output_left_sliced = ctx.make_node('If', [left_pads_greater_than_zero.output[0]])
1998+
final_output_left_sliced.set_body_graph_as_attr("then_branch", output_left_sliced_graph)
1999+
final_output_left_sliced.set_body_graph_as_attr("else_branch", identity_left_sliced_graph)
2000+
# trim right pads
2001+
valid_right_dim = ctx.make_node('Sub', [min_dim.output[0], min_right_pads.output[0]])
2002+
identity_right_sliced_graph = ctx.create_new_graph_with_same_config()
2003+
identity_right_sliced_graph.parent_graph = ctx
2004+
identity_right_sliced = identity_right_sliced_graph.make_node('Identity', [final_output_left_sliced.output[0]])
2005+
identity_right_sliced_graph.add_graph_output(identity_right_sliced.output[0], ctx.get_dtype(node.input[0]),
2006+
loop_output_shape)
2007+
output_right_sliced_graph = ctx.create_new_graph_with_same_config()
2008+
output_right_sliced_graph.parent_graph = ctx
2009+
output_right_sliced = output_right_sliced_graph.make_node('Slice', [final_output_left_sliced.output[0],
2010+
const_zero.output[0],
2011+
valid_right_dim.output[0],
2012+
const_neg_one.output[0]])
2013+
output_right_sliced_graph.add_graph_output(output_right_sliced.output[0], ctx.get_dtype(node.input[0]),
2014+
loop_output_shape)
2015+
right_dim_greater_than_valid = ctx.make_node('Greater', [min_dim.output[0], valid_right_dim.output[0]])
2016+
final_output_right_sliced = ctx.make_node('If', [right_dim_greater_than_valid.output[0]])
2017+
final_output_right_sliced.set_body_graph_as_attr("then_branch", output_right_sliced_graph)
2018+
final_output_right_sliced.set_body_graph_as_attr("else_branch", identity_right_sliced_graph)
2019+
# squeeze output
2020+
latest_shape = ctx.make_node('Shape', [final_output_right_sliced.output[0]])
2021+
latest_depth = ctx.make_node('Slice',
2022+
[latest_shape.output[0], const_neg_two.output[0], const_neg_one.output[0]])
2023+
need_squeeze = ctx.make_node('Equal', [latest_depth.output[0], const_one.output[0]])
2024+
identity_sliced_graph = ctx.create_new_graph_with_same_config()
2025+
identity_sliced_graph.parent_graph = ctx
2026+
identity_sliced = identity_sliced_graph.make_node('Identity', [final_output_right_sliced.output[0]])
2027+
identity_sliced_graph.add_graph_output(identity_sliced.output[0], ctx.get_dtype(node.input[0]),
2028+
raw_output_shape + [-1])
2029+
squeeze_sliced_graph = ctx.create_new_graph_with_same_config()
2030+
squeeze_sliced_graph.parent_graph = ctx
2031+
squeeze_sliced = squeeze_sliced_graph.make_node('Squeeze', [final_output_right_sliced.output[0]],
2032+
attr={'axes': [-2]})
2033+
squeeze_sliced_graph.add_graph_output(squeeze_sliced.output[0], ctx.get_dtype(node.input[0]), raw_output_shape)
2034+
shapes = node.output_shapes
2035+
dtypes = node.output_dtypes
2036+
ctx.remove_node(node.name)
2037+
squeeze_if = ctx.make_node('If', [need_squeeze.output[0]], name=node.name, outputs=node.output, shapes=shapes,
2038+
dtypes=dtypes)
2039+
squeeze_if.set_body_graph_as_attr("then_branch", squeeze_sliced_graph)
2040+
squeeze_if.set_body_graph_as_attr("else_branch", identity_sliced_graph)
2041+
2042+
18222043
@tf_op("BroadcastTo")
18232044
class BroadcastTo:
18242045
@classmethod

0 commit comments

Comments
 (0)