@@ -2052,6 +2052,140 @@ def version_11(cls, ctx, node, **kwargs):
2052
2052
squeeze_if .set_body_graph_as_attr ("then_branch" , squeeze_sliced_graph )
2053
2053
squeeze_if .set_body_graph_as_attr ("else_branch" , identity_sliced_graph )
2054
2054
2055
+ @classmethod
2056
+ def version_12 (cls , ctx , node , ** kwargs ):
2057
+
2058
+ def mkconst (npval , desc ):
2059
+ name = utils .make_name (node .name ) + f'_{ desc } '
2060
+ return ctx .make_const (name , npval ).output [0 ]
2061
+
2062
+ # assemble MatrixDiagPart V2&V3
2063
+ m = node .input [0 ]
2064
+ m_shape = ctx .get_shape (m )
2065
+ utils .make_sure (- 1 not in m_shape , 'At least one dim is unknown %s' , str (m_shape ))
2066
+
2067
+ xlen = m_shape [- 1 ]
2068
+ ylen = m_shape [- 2 ]
2069
+ xlenp = xlen + 1
2070
+ pads = np .zeros (2 * len (m_shape ), dtype = np .int64 )
2071
+ pads [- 2 :] = [1 , 1 ]
2072
+
2073
+ align = 'LEFT_LEFT'
2074
+ if node .op .op_type == 'MatrixDiagPartV3' :
2075
+ align = node .get_attr_str ('align' ) if 'align' in node .attr else 'LEFT_RIGHT'
2076
+ xalign , yalign = align .split ('_' )
2077
+
2078
+ # consts
2079
+ const_neg_one = mkconst (np .array ([- 1 ]).astype (np .int64 ), 'const_neg_one' )
2080
+ const_pad_vals = mkconst (pads , 'pads' )
2081
+ const_zero = mkconst (np .array ([0 ], np .int64 ), 'const_zero_dtype' )
2082
+ const_one = mkconst (np .array ([1 ], np .int64 ), 'const_one_dtype' )
2083
+ const_t = mkconst (np .array ([- 1 , 1 ], np .int64 ), 'const_t' )
2084
+ const_xlen = mkconst (np .array ([xlen ], np .int64 ), 'const_xlen' )
2085
+ const_ylen = mkconst (np .array ([ylen ], np .int64 ), 'const_ylen' )
2086
+ const_stride = mkconst (np .array ([xlenp + 1 ], np .int64 ), 'const_stride' )
2087
+ const_xlenp = mkconst (np .array ([xlenp ], np .int64 ), 'const_xlenp' )
2088
+ const_minxy = mkconst (np .array ([min (xlen , ylen )], np .int64 ), 'const_minxy' )
2089
+ const_xmax = mkconst (np .array ([xlen * xlenp + xlenp - 1 ], np .int64 ), 'const_xmax' )
2090
+ const_ymax = mkconst (np .array ([xlenp * ylen - 1 ], np .int64 ), 'const_ymax' )
2091
+ const_partial_shape = mkconst (np .asarray (m_shape [:- 2 ], np .int64 ), 'partial_shape' )
2092
+ const_m2_shape = mkconst (np .asarray (m_shape [:- 2 ] + [- 1 ], np .int64 ), 'm2_shape' )
2093
+ const_gather_shape = mkconst (np .asarray (m_shape [:- 2 ] + [1 ], np .int64 ), 'gather_shape' )
2094
+
2095
+ # get k0, k1 values. diags to be extracted
2096
+ input1 = ctx .make_node ('Cast' , [node .input [1 ]], attr = {'to' : TensorProto .INT64 })
2097
+ k0 = ctx .make_node ('ReduceMin' , [input1 .output [0 ]])
2098
+ k1 = ctx .make_node ('ReduceMax' , [input1 .output [0 ]])
2099
+ k1_scalar = ctx .make_node ('Squeeze' , [k1 .output [0 ]])
2100
+ m_padded = ctx .make_node ('Pad' , [m , const_pad_vals , node .input [2 ]])
2101
+
2102
+ # starting index for super diagonals
2103
+ xstart_0 = ctx .make_node ('Max' , [const_zero , k0 .output [0 ]])
2104
+ xstart_1 = ctx .make_node ('Add' , [xstart_0 .output [0 ], const_neg_one ])
2105
+ xstart_2 = ctx .make_node ('Range' , [k1_scalar .output [0 ], xstart_1 .output [0 ], const_neg_one ])
2106
+ xstart = ctx .make_node ('Reshape' , [xstart_2 .output [0 ], const_t ])
2107
+
2108
+ # starting indices for sub diagonals
2109
+ ystart_0 = ctx .make_node ('Min' , [const_neg_one , k1 .output [0 ]])
2110
+ ystart_0_scalar = ctx .make_node ('Squeeze' , [ystart_0 .output [0 ]])
2111
+ ystart_1 = ctx .make_node ('Add' , [k0 .output [0 ], const_neg_one ])
2112
+ ystart_2 = ctx .make_node ('Range' , [ystart_0_scalar .output [0 ], ystart_1 .output [0 ], const_neg_one ])
2113
+ ystart = ctx .make_node ('Reshape' , [ystart_2 .output [0 ], const_t ])
2114
+
2115
+ xmax_0 = ctx .make_node ('Mul' , [xstart .output [0 ], const_xlenp ])
2116
+ xmax = ctx .make_node ('Sub' , [const_xmax , xmax_0 .output [0 ]])
2117
+
2118
+ # lengths of super/sub diags to extract
2119
+ xsize_0 = ctx .make_node ('Sub' , [const_xlen , xstart .output [0 ]])
2120
+ xsize = ctx .make_node ('Min' , [xsize_0 .output [0 ], const_minxy ])
2121
+ ysize_0 = ctx .make_node ('Add' , [const_ylen , ystart .output [0 ]])
2122
+ ysize = ctx .make_node ('Min' , [ysize_0 .output [0 ], const_minxy ])
2123
+ diagsize = ctx .make_node ('Concat' , [xsize .output [0 ], ysize .output [0 ]], attr = {'axis' : 0 })
2124
+ maxsize = ctx .make_node ('ReduceMax' , [diagsize .output [0 ]], attr = {'keep_dims' : 0 })
2125
+ maxsize_0 = ctx .make_node ('Reshape' , [maxsize .output [0 ], const_neg_one ])
2126
+ maxsize_scalar = ctx .make_node ('Squeeze' , [maxsize .output [0 ]])
2127
+
2128
+ diagdistances_0 = ctx .make_node ('Range' , [const_zero , maxsize_scalar .output [0 ], const_one ])
2129
+ diagdistances = ctx .make_node ('Mul' , [diagdistances_0 .output [0 ], const_stride ])
2130
+
2131
+ def right_align (sizes , indices , starts , maxval ):
2132
+ op1 = ctx .make_node ('Sub' , [maxsize .output [0 ], sizes .output [0 ]])
2133
+ op2 = ctx .make_node ('Mul' , [op1 .output [0 ], const_stride ])
2134
+ op3 = ctx .make_node ('Sub' , [indices .output [0 ], op2 .output [0 ]])
2135
+ op4 = ctx .make_node ('Less' , [op3 .output [0 ], starts .output [0 ]])
2136
+ op5 = ctx .make_node ('Where' , [op4 .output [0 ], maxval , op3 .output [0 ]])
2137
+ return op5
2138
+
2139
+ # xdiags, ydiags contain indices of diagonal elements
2140
+ xdiags_0 = ctx .make_node ('Add' , [xstart .output [0 ], diagdistances .output [0 ]])
2141
+ if xalign == 'RIGHT' :
2142
+ xdiags = right_align (xsize , xdiags_0 , xstart , const_ymax )
2143
+ else :
2144
+ xdiags = ctx .make_node ('Min' , [xdiags_0 .output [0 ], xmax .output [0 ]])
2145
+
2146
+ ydiags_0_ = ctx .make_node ('Abs' , [ystart .output [0 ]])
2147
+ ydiags_1 = ctx .make_node ('Mul' , [ydiags_0_ .output [0 ], const_xlenp ])
2148
+ ydiags_2 = ctx .make_node ('Add' , [ydiags_1 .output [0 ], diagdistances .output [0 ]])
2149
+ if yalign == 'RIGHT' :
2150
+ ydiags = right_align (ysize , ydiags_2 , ydiags_1 , const_ymax )
2151
+ else :
2152
+ ydiags = ctx .make_node ('Min' , [ydiags_2 .output [0 ], const_ymax ])
2153
+
2154
+ # flatten last dimension of matrix
2155
+ m2 = ctx .make_node ('Reshape' , [m_padded .output [0 ], const_m2_shape ])
2156
+
2157
+ diags_0 = ctx .make_node ('Concat' , [xdiags .output [0 ], ydiags .output [0 ]], attr = {'axis' : 0 })
2158
+ diags_1 = ctx .make_node ('Reshape' , [diags_0 .output [0 ], const_neg_one ])
2159
+ diags_2 = ctx .make_node ('Expand' , [diags_1 .output [0 ], const_gather_shape ])
2160
+ diags = ctx .make_node ('GatherElements' , [m2 .output [0 ], diags_2 .output [0 ]], attr = {'axis' : - 1 })
2161
+
2162
+ # if k0=k1, rank of output matrix is 1 less than usual.
2163
+ # hence, need 'If' to compute right output matrix shape
2164
+ def compute_out_shape (k0_k1_same = False ):
2165
+ g = ctx .create_new_graph_with_same_config ()
2166
+ g .parent_graph = ctx
2167
+ if k0_k1_same :
2168
+ outshape = g .make_node ('Concat' , [const_partial_shape , maxsize_0 .output [0 ]], attr = {'axis' : 0 })
2169
+ else :
2170
+ outshape = g .make_node ('Concat' , [const_partial_shape , const_neg_one , maxsize_0 .output [0 ]], attr = {'axis' : 0 })
2171
+ g .add_graph_output (outshape .output [0 ], TensorProto .INT64 , [- 1 ])
2172
+ return g
2173
+
2174
+ k0_k1_same = ctx .make_node ('Equal' , [k1 .output [0 ], k0 .output [0 ]])
2175
+ if_node = ctx .make_node ('If' , [k0_k1_same .output [0 ]])
2176
+ if_node .set_body_graph_as_attr ('then_branch' , compute_out_shape (True ))
2177
+ if_node .set_body_graph_as_attr ('else_branch' , compute_out_shape (False ))
2178
+
2179
+ shapes = [- 1 ] * len (m_shape )
2180
+ dtypes = node .output_dtypes
2181
+ ctx .remove_node (node .name )
2182
+ ctx .make_node ('Reshape' , [diags .output [0 ], if_node .output [0 ]], name = node .name , outputs = node .output ,
2183
+ shapes = [shapes ], dtypes = dtypes )
2184
+
2185
+ for consumer in ctx .find_output_consumers (node .output [0 ]):
2186
+ if consumer .type == 'Identity' :
2187
+ ctx .set_shape (consumer .output [0 ], shapes )
2188
+
2055
2189
2056
2190
@tf_op ("BroadcastTo" )
2057
2191
class BroadcastTo :
0 commit comments