Skip to content

Commit 4a3c491

Browse files
authored
Merge pull request #948 from jignparm/jignparm/matrixdiagpartv3_2
MatrixDiagPartV3: Change consts to dynamic ops
2 parents c1eb6a8 + 58958f2 commit 4a3c491

File tree

1 file changed: +60 -56 lines changed

1 file changed: +60 -56 lines changed

tf2onnx/onnx_opset/tensor.py

Lines changed: 60 additions & 56 deletions
Original file line number | Diff line number | Diff line change
@@ -1835,94 +1835,98 @@ class MatrixDiagPartV2V3:
18351835
@classmethod
18361836
def version_11(cls, ctx, node, **kwargs):
18371837

1838-
def mkconst(npval, desc):
1839-
name = utils.make_name(node.name) + f'_{desc}'
1840-
return ctx.make_const(name, npval).output[0]
1838+
def mkconsts(values, dtype=np.int64):
1839+
ret = []
1840+
for value in values:
1841+
name = utils.make_name(node.name + '_const')
1842+
ret.append(ctx.make_const(name, np.array(value, dtype=dtype)).output[0])
1843+
return ret
18411844

18421845
# assemble MatrixDiagPart V2&V3
18431846
m = node.input[0]
18441847
m_shape = ctx.get_shape(m)
1845-
utils.make_sure(-1 not in m_shape, 'At least one dim is unknown %s', str(m_shape))
1846-
1847-
xlen = m_shape[-1]
1848-
ylen = m_shape[-2]
1849-
xlenp = xlen + 1
1850-
pads = np.zeros(2 * len(m_shape), dtype=np.int64)
1848+
m_rank = len(m_shape)
1849+
pads = np.zeros(2 * m_rank, dtype=np.int64)
18511850
pads[-2:] = [1, 1]
1851+
utils.make_sure(m_rank > 1, 'Input data should be at least 2D %s', str(m_shape))
18521852

18531853
align = 'LEFT_LEFT'
18541854
if node.op.op_type == 'MatrixDiagPartV3':
18551855
align = node.get_attr_str('align') if 'align' in node.attr else 'LEFT_RIGHT'
18561856
xalign, yalign = align.split('_')
18571857

18581858
# consts
1859-
const_zero = mkconst(np.array([0], np.int64), 'const_zero_dtype')
1860-
const_zero_float = mkconst(np.array([0], np.float32), 'const_zero_dtype_f')
1861-
const_one = mkconst(np.array([1], np.int64), 'const_one_dtype')
1862-
const_neg_one = mkconst(np.array([-1]).astype(np.int64), 'const_neg_one')
1863-
const_neg_one_float = mkconst(np.array([-1]).astype(np.float32), 'const_neg_one_f')
1864-
const_pad_vals = mkconst(pads, 'pads')
1865-
const_t = mkconst(np.array([-1, 1], np.int64), 'const_t')
1866-
const_xlen = mkconst(np.array([xlen], np.int64), 'const_xlen')
1867-
const_ylen = mkconst(np.array([ylen], np.int64), 'const_ylen')
1868-
const_xlenp = mkconst(np.array([xlenp], np.int64), 'const_xlenp')
1869-
const_stride = mkconst(np.array([xlenp + 1], np.int64), 'const_stride')
1870-
const_minxy_float = mkconst(np.array([min(xlen, ylen)], np.float32), 'const_minxy_f')
1871-
const_xmax = mkconst(np.array([xlen * xlenp + xlenp - 1], np.int64), 'const_xmax')
1872-
const_ymax = mkconst(np.array([xlenp * ylen - 1], np.int64), 'const_ymax')
1873-
const_ymax_float = mkconst(np.array([xlenp * ylen - 1], np.float32), 'const_ymax_f')
1874-
const_partial_shape = mkconst(np.asarray(m_shape[:-2], np.int64), 'partial_shape')
1875-
const_m2_shape = mkconst(np.asarray(m_shape[:-2] + [-1], np.int64), 'm2_shape')
1876-
const_gather_shape = mkconst(np.asarray(m_shape[:-2] + [1], np.int64), 'gather_shape')
1859+
const_zero_float, const_neg_one_float = mkconsts([[0], [-1]], np.float32)
1860+
const_zero, const_one, const_neg_one, const_neg_two, const_pad_vals, const_t = \
1861+
mkconsts([[0], [1], [-1], [-2], pads, [-1, 1]])
1862+
const_zero_scalar, const_one_scalar, const_neg_one_scalar = mkconsts([0, 1, -1])
1863+
1864+
m_shape = ctx.make_node('Shape', [node.input[0]]).output[0]
1865+
xlen = ctx.make_node('Gather', [m_shape, const_neg_one]).output[0]
1866+
ylen = ctx.make_node('Gather', [m_shape, const_neg_two]).output[0]
1867+
xlenp = ctx.make_node('Add', [xlen, const_one]).output[0]
1868+
stride = ctx.make_node('Add', [xlenp, const_one]).output[0]
1869+
minxy_0 = ctx.make_node('Concat', [xlen, ylen], attr={'axis': 0}).output[0]
1870+
minxy = ctx.make_node('ReduceMin', [minxy_0]).output[0]
1871+
minxy_float = ctx.make_node('Cast', [minxy], attr={'to': TensorProto.FLOAT}).output[0]
1872+
xmax_0 = ctx.make_node('Mul', [xlen, xlenp]).output[0]
1873+
xmax_1 = ctx.make_node('Add', [xmax_0, xlenp]).output[0]
1874+
xmax = ctx.make_node('Add', [xmax_1, const_neg_one]).output[0]
1875+
ymax_0 = ctx.make_node('Mul', [xlenp, ylen]).output[0]
1876+
ymax = ctx.make_node('Add', [ymax_0, const_neg_one]).output[0]
1877+
ymax_float = ctx.make_node('Cast', [ymax], attr={'to': TensorProto.FLOAT}).output[0]
1878+
partial_shape = ctx.make_node('Slice', [m_shape, const_zero, const_neg_two]).output[0]
1879+
m2_shape = ctx.make_node('Concat', [partial_shape, const_neg_one], attr={'axis': 0}).output[0]
1880+
gather_shape = ctx.make_node('Concat', [partial_shape, const_one], attr={'axis': 0}).output[0]
18771881

18781882
# get k0, k1 values. diags to be extracted
18791883
input1 = ctx.make_node('Cast', [node.input[1]], attr={'to': TensorProto.INT64})
1880-
k0 = ctx.make_node('ReduceMin', [input1.output[0]])
1881-
k1 = ctx.make_node('ReduceMax', [input1.output[0]])
1882-
k1_scalar = ctx.make_node('Squeeze', [k1.output[0]])
1884+
k0 = ctx.make_node('ReduceMin', [input1.output[0]]).output[0]
1885+
k1 = ctx.make_node('ReduceMax', [input1.output[0]]).output[0]
1886+
k0_scalar = ctx.make_node('Squeeze', [k0]).output[0]
1887+
k1_scalar = ctx.make_node('Squeeze', [k1]).output[0]
18831888
m_padded = ctx.make_node('Pad', [m, const_pad_vals, node.input[2]])
18841889

18851890
# starting indexes for super diagonals
1886-
xstart_0 = ctx.make_node('Cast', [k0.output[0]], attr={'to': TensorProto.FLOAT})
1891+
xstart_0 = ctx.make_node('Cast', [k0_scalar], attr={'to': TensorProto.FLOAT})
18871892
xstart_1 = ctx.make_node('Max', [const_zero_float, xstart_0.output[0]])
18881893
xstart_2 = ctx.make_node('Cast', [xstart_1.output[0]], attr={'to': TensorProto.INT64})
1889-
xstart_3 = ctx.make_node('Add', [xstart_2.output[0], const_neg_one])
1890-
xstart_4 = ctx.make_node('Range', [k1_scalar.output[0], xstart_3.output[0], const_neg_one])
1894+
xstart_3 = ctx.make_node('Add', [xstart_2.output[0], const_neg_one_scalar])
1895+
xstart_4 = ctx.make_node('Range', [k1_scalar, xstart_3.output[0], const_neg_one_scalar])
18911896
xstart = ctx.make_node('Reshape', [xstart_4.output[0], const_t])
18921897

18931898
# starting indexes for sub diagonals
1894-
ystart_0 = ctx.make_node('Cast', [k1.output[0]], attr={'to': TensorProto.FLOAT})
1899+
ystart_0 = ctx.make_node('Cast', [k1_scalar], attr={'to': TensorProto.FLOAT})
18951900
ystart_1 = ctx.make_node('Min', [const_neg_one_float, ystart_0.output[0]])
18961901
ystart_2 = ctx.make_node('Cast', [ystart_1.output[0]], attr={'to': TensorProto.INT64})
1897-
ystart_2_scalar = ctx.make_node('Squeeze', [ystart_2.output[0]])
1898-
ystart_3 = ctx.make_node('Add', [k0.output[0], const_neg_one])
1899-
ystart_4 = ctx.make_node('Range', [ystart_2_scalar.output[0], ystart_3.output[0], const_neg_one])
1902+
ystart_3 = ctx.make_node('Add', [k0_scalar, const_neg_one_scalar])
1903+
ystart_4 = ctx.make_node('Range', [ystart_2.output[0], ystart_3.output[0], const_neg_one_scalar])
19001904
ystart = ctx.make_node('Reshape', [ystart_4.output[0], const_t])
19011905

1902-
xmax_0 = ctx.make_node('Mul', [xstart.output[0], const_xlenp])
1903-
xmax = ctx.make_node('Sub', [const_xmax, xmax_0.output[0]])
1906+
xmax_0 = ctx.make_node('Mul', [xstart.output[0], xlenp])
1907+
xmax = ctx.make_node('Sub', [xmax, xmax_0.output[0]])
19041908
xmax_float = ctx.make_node('Cast', [xmax.output[0]], attr={'to': TensorProto.FLOAT})
19051909

19061910
# lengths of super/sub diags to extract
1907-
xsize_0 = ctx.make_node('Sub', [const_xlen, xstart.output[0]])
1911+
xsize_0 = ctx.make_node('Sub', [xlen, xstart.output[0]])
19081912
xsize_1 = ctx.make_node('Cast', [xsize_0.output[0]], attr={'to': TensorProto.FLOAT})
1909-
xsize_2 = ctx.make_node('Min', [xsize_1.output[0], const_minxy_float])
1913+
xsize_2 = ctx.make_node('Min', [xsize_1.output[0], minxy_float])
19101914
xsize = ctx.make_node('Cast', [xsize_2.output[0]], attr={'to': TensorProto.INT64})
1911-
ysize_0 = ctx.make_node('Add', [const_ylen, ystart.output[0]])
1915+
ysize_0 = ctx.make_node('Add', [ylen, ystart.output[0]])
19121916
ysize_1 = ctx.make_node('Cast', [ysize_0.output[0]], attr={'to': TensorProto.FLOAT})
1913-
ysize_2 = ctx.make_node('Min', [ysize_1.output[0], const_minxy_float])
1917+
ysize_2 = ctx.make_node('Min', [ysize_1.output[0], minxy_float])
19141918
ysize = ctx.make_node('Cast', [ysize_2.output[0]], attr={'to': TensorProto.INT64})
19151919
diagsize = ctx.make_node('Concat', [xsize.output[0], ysize.output[0]], attr={'axis': 0})
19161920
maxsize = ctx.make_node('ReduceMax', [diagsize.output[0]], attr={'keep_dims': 0})
19171921
maxsize_0 = ctx.make_node('Reshape', [maxsize.output[0], const_neg_one])
19181922
maxsize_scalar = ctx.make_node('Squeeze', [maxsize.output[0]])
19191923

1920-
diagdistances_0 = ctx.make_node('Range', [const_zero, maxsize_scalar.output[0], const_one])
1921-
diagdistances = ctx.make_node('Mul', [diagdistances_0.output[0], const_stride])
1924+
diagdistances_0 = ctx.make_node('Range', [const_zero_scalar, maxsize_scalar.output[0], const_one_scalar])
1925+
diagdistances = ctx.make_node('Mul', [diagdistances_0.output[0], stride])
19221926

19231927
def right_align(sizes, indices, starts, maxval):
19241928
op1 = ctx.make_node('Sub', [maxsize.output[0], sizes.output[0]])
1925-
op2 = ctx.make_node('Mul', [op1.output[0], const_stride])
1929+
op2 = ctx.make_node('Mul', [op1.output[0], stride])
19261930
op3 = ctx.make_node('Sub', [indices.output[0], op2.output[0]])
19271931
op4 = ctx.make_node('Less', [op3.output[0], starts.output[0]])
19281932
op5 = ctx.make_node('Where', [op4.output[0], maxval, op3.output[0]])
@@ -1932,48 +1936,48 @@ def right_align(sizes, indices, starts, maxval):
19321936
xdiags_0 = ctx.make_node('Add', [xstart.output[0], diagdistances.output[0]])
19331937
xdiags_1 = ctx.make_node('Cast', [xdiags_0.output[0]], attr={'to': TensorProto.FLOAT})
19341938
if xalign == 'RIGHT':
1935-
xdiags = right_align(xsize, xdiags_0, xstart, const_ymax)
1939+
xdiags = right_align(xsize, xdiags_0, xstart, ymax)
19361940
else:
19371941
xdiags_2 = ctx.make_node('Min', [xdiags_1.output[0], xmax_float.output[0]])
19381942
xdiags = ctx.make_node('Cast', [xdiags_2.output[0]], attr={'to': TensorProto.INT64})
19391943

19401944
ydiags_0_ = ctx.make_node('Abs', [ystart.output[0]])
1941-
ydiags_1 = ctx.make_node('Mul', [ydiags_0_.output[0], const_xlenp])
1945+
ydiags_1 = ctx.make_node('Mul', [ydiags_0_.output[0], xlenp])
19421946
ydiags_2 = ctx.make_node('Add', [ydiags_1.output[0], diagdistances.output[0]])
19431947
ydiags_3 = ctx.make_node('Cast', [ydiags_2.output[0]], attr={'to': TensorProto.FLOAT})
19441948
if yalign == 'RIGHT':
1945-
ydiags = right_align(ysize, ydiags_2, ydiags_1, const_ymax)
1949+
ydiags = right_align(ysize, ydiags_2, ydiags_1, ymax)
19461950
else:
1947-
ydiags_4 = ctx.make_node('Min', [ydiags_3.output[0], const_ymax_float])
1951+
ydiags_4 = ctx.make_node('Min', [ydiags_3.output[0], ymax_float])
19481952
ydiags = ctx.make_node('Cast', [ydiags_4.output[0]], attr={'to': TensorProto.INT64})
19491953

19501954
# flatten last dimension of matrix
1951-
m2 = ctx.make_node('Reshape', [m_padded.output[0], const_m2_shape])
1955+
m2 = ctx.make_node('Reshape', [m_padded.output[0], m2_shape])
19521956

19531957
diags_0 = ctx.make_node('Concat', [xdiags.output[0], ydiags.output[0]], attr={'axis': 0})
19541958
diags_1 = ctx.make_node('Reshape', [diags_0.output[0], const_neg_one])
1955-
diags_2 = ctx.make_node('Expand', [diags_1.output[0], const_gather_shape])
1959+
diags_2 = ctx.make_node('Expand', [diags_1.output[0], gather_shape])
19561960
diags = ctx.make_node('GatherElements', [m2.output[0], diags_2.output[0]], attr={'axis': -1})
19571961

19581962
def compute_out_shape(k0_k1_same=False):
19591963
g = ctx.create_new_graph_with_same_config()
19601964
g.parent_graph = ctx
19611965
if k0_k1_same:
1962-
dims = [const_partial_shape, maxsize_0.output[0]]
1966+
dims = [partial_shape, maxsize_0.output[0]]
19631967
else:
1964-
dims = [const_partial_shape, const_neg_one, maxsize_0.output[0]]
1968+
dims = [partial_shape, const_neg_one, maxsize_0.output[0]]
19651969
outshape = g.make_node('Concat', dims, attr={'axis': 0})
19661970
g.add_graph_output(outshape.output[0], TensorProto.INT64, [-1])
19671971
return g
19681972

19691973
# if k0=k1, rank of output matrix is 1 less than usual
19701974
# hence, need 'If' to compute right output matrix shape
1971-
k0_k1_same = ctx.make_node('Equal', [k1.output[0], k0.output[0]])
1975+
k0_k1_same = ctx.make_node('Equal', [k1, k0])
19721976
if_node = ctx.make_node('If', [k0_k1_same.output[0]])
19731977
if_node.set_body_graph_as_attr('then_branch', compute_out_shape(True))
19741978
if_node.set_body_graph_as_attr('else_branch', compute_out_shape(False))
19751979

1976-
shapes = [-1] * len(m_shape)
1980+
shapes = ctx.get_shape(node.output[0])
19771981
dtypes = node.output_dtypes
19781982
ctx.remove_node(node.name)
19791983
ctx.make_node('Reshape', [diags.output[0], if_node.output[0]], name=node.name, outputs=node.output,

0 commit comments

Comments
 (0)