Skip to content

Commit c60ff32

Browse files
committed
Add non-loop version of MatrixDiagPartV3, opset=12
1 parent 9f5d63b commit c60ff32

File tree

1 file changed

+134
-0
lines changed

1 file changed

+134
-0
lines changed

tf2onnx/onnx_opset/tensor.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2052,6 +2052,140 @@ def version_11(cls, ctx, node, **kwargs):
20522052
squeeze_if.set_body_graph_as_attr("then_branch", squeeze_sliced_graph)
20532053
squeeze_if.set_body_graph_as_attr("else_branch", identity_sliced_graph)
20542054

2055+
@classmethod
2056+
def version_12(cls, ctx, node, **kwargs):
2057+
2058+
def mkconst(npval, desc):
2059+
name = utils.make_name(node.name) + f'_{desc}'
2060+
return ctx.make_const(name, npval).output[0]
2061+
2062+
# assemble MatrixDiagPart V2&V3
2063+
m = node.input[0]
2064+
m_shape = ctx.get_shape(m)
2065+
utils.make_sure(-1 not in m_shape, 'At least one dim is unknown %s', str(m_shape))
2066+
2067+
xlen = m_shape[-1]
2068+
ylen = m_shape[-2]
2069+
xlenp = xlen + 1
2070+
pads = np.zeros(2 * len(m_shape), dtype=np.int64)
2071+
pads[-2:] = [1, 1]
2072+
2073+
align = 'LEFT_LEFT'
2074+
if node.op.op_type == 'MatrixDiagPartV3':
2075+
align = node.get_attr_str('align') if 'align' in node.attr else 'LEFT_RIGHT'
2076+
xalign, yalign = align.split('_')
2077+
2078+
# consts
2079+
const_neg_one = mkconst(np.array([-1]).astype(np.int64), 'const_neg_one')
2080+
const_pad_vals = mkconst(pads, 'pads')
2081+
const_zero = mkconst(np.array([0], np.int64), 'const_zero_dtype')
2082+
const_one = mkconst(np.array([1], np.int64), 'const_one_dtype')
2083+
const_t = mkconst(np.array([-1, 1], np.int64), 'const_t')
2084+
const_xlen = mkconst(np.array([xlen], np.int64), 'const_xlen')
2085+
const_ylen = mkconst(np.array([ylen], np.int64), 'const_ylen')
2086+
const_stride = mkconst(np.array([xlenp + 1], np.int64), 'const_stride')
2087+
const_xlenp = mkconst(np.array([xlenp], np.int64), 'const_xlenp')
2088+
const_minxy = mkconst(np.array([min(xlen, ylen)], np.int64), 'const_minxy')
2089+
const_xmax = mkconst(np.array([xlen * xlenp + xlenp - 1], np.int64), 'const_xmax')
2090+
const_ymax = mkconst(np.array([xlenp * ylen - 1], np.int64), 'const_ymax')
2091+
const_partial_shape = mkconst(np.asarray(m_shape[:-2], np.int64), 'partial_shape')
2092+
const_m2_shape = mkconst(np.asarray(m_shape[:-2] + [-1], np.int64), 'm2_shape')
2093+
const_gather_shape = mkconst(np.asarray(m_shape[:-2] + [1], np.int64), 'gather_shape')
2094+
2095+
# get k0, k1 values. diags to be extracted
2096+
input1 = ctx.make_node('Cast', [node.input[1]], attr={'to': TensorProto.INT64})
2097+
k0 = ctx.make_node('ReduceMin', [input1.output[0]])
2098+
k1 = ctx.make_node('ReduceMax', [input1.output[0]])
2099+
k1_scalar = ctx.make_node('Squeeze', [k1.output[0]])
2100+
m_padded = ctx.make_node('Pad', [m, const_pad_vals, node.input[2]])
2101+
2102+
# starting index for super diagonals
2103+
xstart_0 = ctx.make_node('Max', [const_zero, k0.output[0]])
2104+
xstart_1 = ctx.make_node('Add', [xstart_0.output[0], const_neg_one])
2105+
xstart_2 = ctx.make_node('Range', [k1_scalar.output[0], xstart_1.output[0], const_neg_one])
2106+
xstart = ctx.make_node('Reshape', [xstart_2.output[0], const_t])
2107+
2108+
# starting indices for sub diagonals
2109+
ystart_0 = ctx.make_node('Min', [const_neg_one, k1.output[0]])
2110+
ystart_0_scalar = ctx.make_node('Squeeze', [ystart_0.output[0]])
2111+
ystart_1 = ctx.make_node('Add', [k0.output[0], const_neg_one])
2112+
ystart_2 = ctx.make_node('Range', [ystart_0_scalar.output[0], ystart_1.output[0], const_neg_one])
2113+
ystart = ctx.make_node('Reshape', [ystart_2.output[0], const_t])
2114+
2115+
xmax_0 = ctx.make_node('Mul', [xstart.output[0], const_xlenp])
2116+
xmax = ctx.make_node('Sub', [const_xmax, xmax_0.output[0]])
2117+
2118+
# lengths of super/sub diags to extract
2119+
xsize_0 = ctx.make_node('Sub', [const_xlen, xstart.output[0]])
2120+
xsize = ctx.make_node('Min', [xsize_0.output[0], const_minxy])
2121+
ysize_0 = ctx.make_node('Add', [const_ylen, ystart.output[0]])
2122+
ysize = ctx.make_node('Min', [ysize_0.output[0], const_minxy])
2123+
diagsize = ctx.make_node('Concat', [xsize.output[0], ysize.output[0]], attr={'axis': 0})
2124+
maxsize = ctx.make_node('ReduceMax', [diagsize.output[0]], attr={'keep_dims': 0})
2125+
maxsize_0 = ctx.make_node('Reshape', [maxsize.output[0], const_neg_one])
2126+
maxsize_scalar = ctx.make_node('Squeeze', [maxsize.output[0]])
2127+
2128+
diagdistances_0 = ctx.make_node('Range', [const_zero, maxsize_scalar.output[0], const_one])
2129+
diagdistances = ctx.make_node('Mul', [diagdistances_0.output[0], const_stride])
2130+
2131+
def right_align(sizes, indices, starts, maxval):
2132+
op1 = ctx.make_node('Sub', [maxsize.output[0], sizes.output[0]])
2133+
op2 = ctx.make_node('Mul', [op1.output[0], const_stride])
2134+
op3 = ctx.make_node('Sub', [indices.output[0], op2.output[0]])
2135+
op4 = ctx.make_node('Less', [op3.output[0], starts.output[0]])
2136+
op5 = ctx.make_node('Where', [op4.output[0], maxval, op3.output[0]])
2137+
return op5
2138+
2139+
# xdiags, ydiags contain indices of diagonal elements
2140+
xdiags_0 = ctx.make_node('Add', [xstart.output[0], diagdistances.output[0]])
2141+
if xalign == 'RIGHT':
2142+
xdiags = right_align(xsize, xdiags_0, xstart, const_ymax)
2143+
else:
2144+
xdiags = ctx.make_node('Min', [xdiags_0.output[0], xmax.output[0]])
2145+
2146+
ydiags_0_ = ctx.make_node('Abs', [ystart.output[0]])
2147+
ydiags_1 = ctx.make_node('Mul', [ydiags_0_.output[0], const_xlenp])
2148+
ydiags_2 = ctx.make_node('Add', [ydiags_1.output[0], diagdistances.output[0]])
2149+
if yalign == 'RIGHT':
2150+
ydiags = right_align(ysize, ydiags_2, ydiags_1, const_ymax)
2151+
else:
2152+
ydiags = ctx.make_node('Min', [ydiags_2.output[0], const_ymax])
2153+
2154+
# flatten last dimension of matrix
2155+
m2 = ctx.make_node('Reshape', [m_padded.output[0], const_m2_shape])
2156+
2157+
diags_0 = ctx.make_node('Concat', [xdiags.output[0], ydiags.output[0]], attr={'axis': 0})
2158+
diags_1 = ctx.make_node('Reshape', [diags_0.output[0], const_neg_one])
2159+
diags_2 = ctx.make_node('Expand', [diags_1.output[0], const_gather_shape])
2160+
diags = ctx.make_node('GatherElements', [m2.output[0], diags_2.output[0]], attr={'axis': -1})
2161+
2162+
# if k0=k1, rank of output matrix is 1 less than usual.
2163+
# hence, need 'If' to compute right output matrix shape
2164+
def compute_out_shape(k0_k1_same=False):
2165+
g = ctx.create_new_graph_with_same_config()
2166+
g.parent_graph = ctx
2167+
if k0_k1_same:
2168+
outshape = g.make_node('Concat', [const_partial_shape, maxsize_0.output[0]], attr={'axis': 0})
2169+
else:
2170+
outshape = g.make_node('Concat', [const_partial_shape, const_neg_one, maxsize_0.output[0]], attr={'axis': 0})
2171+
g.add_graph_output(outshape.output[0], TensorProto.INT64, [-1])
2172+
return g
2173+
2174+
k0_k1_same = ctx.make_node('Equal', [k1.output[0], k0.output[0]])
2175+
if_node = ctx.make_node('If', [k0_k1_same.output[0]])
2176+
if_node.set_body_graph_as_attr('then_branch', compute_out_shape(True))
2177+
if_node.set_body_graph_as_attr('else_branch', compute_out_shape(False))
2178+
2179+
shapes = [-1] * len(m_shape)
2180+
dtypes = node.output_dtypes
2181+
ctx.remove_node(node.name)
2182+
ctx.make_node('Reshape', [diags.output[0], if_node.output[0]], name=node.name, outputs=node.output,
2183+
shapes=[shapes], dtypes=dtypes)
2184+
2185+
for consumer in ctx.find_output_consumers(node.output[0]):
2186+
if consumer.type == 'Identity':
2187+
ctx.set_shape(consumer.output[0], shapes)
2188+
20552189

20562190
@tf_op("BroadcastTo")
20572191
class BroadcastTo:

0 commit comments

Comments
 (0)