@@ -2207,11 +2207,13 @@ def compute_out_shape(k0_k1_same=False):
2207
2207
ctx .set_shape (consumer .output [0 ], shapes )
2208
2208
2209
2209
2210
- @tf_op (" MatrixDiagV3" )
2211
- class MatrixDiagV3 :
2210
+ @tf_op ([ "MatrixDiag" , "MatrixDiagV2" , " MatrixDiagV3"] )
2211
+ class MatrixDiag :
2212
2212
@classmethod
2213
2213
def version_12 (cls , ctx , node , ** kwargs ):
2214
2214
# Assemble MatrixDiagV3 by ReverseSequence
2215
+ argc = len (node .input )
2216
+
2215
2217
def mkconsts (values ):
2216
2218
return [ctx .make_const (utils .make_name ('const' ), \
2217
2219
np .array (value ).astype (np .int64 )).output [0 ] for value in values ]
@@ -2230,6 +2232,9 @@ def normalize(name):
2230
2232
reshaped = mknode ("Reshape" , [casted , minus_one ])
2231
2233
return reshaped
2232
2234
2235
+ def cast (name ):
2236
+ return mknode ("Cast" , [name ], attr = {"to" : ctx .get_dtype (node .input [0 ])})
2237
+
2233
2238
def processdiag ():
2234
2239
# unsqueeze diag if necessary
2235
2240
diag = node .input [0 ]
@@ -2241,7 +2246,7 @@ def processdiag():
2241
2246
2242
2247
diag_shape = mknode ("Shape" , [diag ])
2243
2248
diag_depth = mknode ("Slice" , [diag_shape , minus_two , minus_one ])
2244
- k = normalize (node .input [1 ])
2249
+ k = normalize (node .input [1 ]) if argc > 1 else zeo
2245
2250
k_min , k_max = mknode ("ReduceMin" , [k ]), mknode ("ReduceMax" , [k ])
2246
2251
k_max_nxt = mknode ("Add" , [k_max , one ])
2247
2252
k_depth = mknode ("Sub" , [k_max_nxt , k_min ])
@@ -2272,8 +2277,10 @@ def squeeze(name):
2272
2277
2273
2278
# gather inputs
2274
2279
diag , k , k_min , k_max , k_max_nxt = processdiag ()
2275
- row , col , pad , align = normalize (node .input [2 ]), normalize (node .input [3 ]), \
2276
- node .input [4 ], node .get_attr_str ("align" )
2280
+ row , col , pad , align = normalize (node .input [2 ]) if argc > 2 else minus_one , \
2281
+ normalize (node .input [3 ]) if argc > 3 else minus_one , \
2282
+ node .input [4 ] if argc > 4 else cast (zeo ), \
2283
+ node .get_attr_str ("align" ) if "align" in node .attr else "LEFT_LEFT"
2277
2284
2278
2285
diag_shape = mknode ("Shape" , [diag ])
2279
2286
diag_rank = mknode ("Shape" , [diag_shape ])
@@ -2580,12 +2587,12 @@ def normalize():
2580
2587
# make matrix of bool
2581
2588
ctx .set_dtype (ones_diag .output [0 ], TensorProto .INT64 )
2582
2589
ones_matrix = ctx .make_node ("MatrixDiagV3" , [ones_diag .output [0 ], k , row , col , zeo ], attr )
2583
- MatrixDiagV3 .version_12 (ctx , ones_matrix )
2590
+ MatrixDiag .version_12 (ctx , ones_matrix )
2584
2591
ones_bool = mknode ("Equal" , [ones_matrix .output [0 ], one ])
2585
2592
2586
2593
# make matrix out of diag
2587
2594
diag_matrix = ctx .make_node ("MatrixDiagV3" , [diag , k , row , col , cast (zeo )], attr )
2588
- MatrixDiagV3 .version_12 (ctx , diag_matrix )
2595
+ MatrixDiag .version_12 (ctx , diag_matrix )
2589
2596
2590
2597
shapes = node .output_shapes
2591
2598
dtypes = node .output_dtypes
0 commit comments