@@ -3182,10 +3182,10 @@ def test_matrix_diag_part_v3(self):
3182
3182
3183
3183
def func (X , K ):
3184
3184
v2 = tf .raw_ops .MatrixDiagPartV2 (input = X , k = K , padding_value = 0.123 , name = _TFOUTPUT )
3185
- v3 = tf .raw_ops .MatrixDiagPartV3 (input = X , k = K , padding_value = 0.123 , align = 'RIGHT_LEFT ' , name = _TFOUTPUT1 )
3185
+ v3 = tf .raw_ops .MatrixDiagPartV3 (input = X , k = K , padding_value = 0.123 , align = 'LEFT_RIGHT ' , name = _TFOUTPUT1 )
3186
3186
return v2 , v3
3187
3187
3188
- for x_shape in ([4 , 5 ], [2 , 3 , 4 , 5 ]):
3188
+ for x_shape in ([4 , 5 ], [2 , 3 , 4 , 5 ], [ 5 , 4 ], [ 7 , 5 ] ):
3189
3189
x_val = np .random .random (x_shape ).astype (np .float32 )
3190
3190
for raw_k in ([0 ], [1 ], [3 ], [- 1 ], [- 3 ], [1 , 2 ], [- 2 , - 1 ], [- 1 , 1 ]):
3191
3191
k_val = np .array (raw_k ).astype (np .int32 )
@@ -3235,6 +3235,116 @@ def func(x):
3235
3235
return tf .identity (y , name = _TFOUTPUT )
3236
3236
self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val })
3237
3237
3238
+ @check_opset_min_version (12 )
3239
+ @check_tf_min_version ("2.2" )
3240
+ def test_matrix_diag_v3_multi_dim (self ):
3241
+ raw_diag = [[[1.0 , 2.0 , 3.0 ],
3242
+ [4.0 , 5.0 , 6.0 ],
3243
+ [7.0 , 8.0 , 9.0 ]],
3244
+ [[10.0 , 11.0 , 12.0 ],
3245
+ [13.0 , 14.0 , 15.0 ],
3246
+ [16.0 , 17.0 , 18.0 ]]]
3247
+ diag_val = np .array (raw_diag ).astype (np .float32 )
3248
+ k_val = np .array ([- 1 , 1 ]).astype (np .int32 )
3249
+ row_val = np .array (- 1 ).astype (np .int32 )
3250
+ col_val = np .array (- 1 ).astype (np .int32 )
3251
+
3252
+ def func (diag , k , row , col ):
3253
+ return tf .raw_ops .MatrixDiagV3 (diagonal = diag , k = k , num_rows = row , num_cols = col ,
3254
+ padding_value = 0.123 , align = 'RIGHT_RIGHT' , name = _TFOUTPUT ), \
3255
+ tf .raw_ops .MatrixDiagV2 (diagonal = diag , k = k , num_rows = row , num_cols = col ,
3256
+ padding_value = 0.123 , name = _TFOUTPUT1 )
3257
+
3258
+ self ._run_test_case (func , [_OUTPUT , _OUTPUT1 ], {_INPUT : diag_val , _INPUT1 : k_val ,
3259
+ _INPUT2 : row_val , _INPUT3 : col_val })
3260
+
3261
+ @check_opset_min_version (12 )
3262
+ @check_tf_min_version ("2.2" )
3263
+ def test_matrix_diag_v3_multi_dim_min_row (self ):
3264
+ raw_diag = [[[1.0 , 2.0 , 3.0 ],
3265
+ [4.0 , 5.0 , 6.0 ]],
3266
+ [[7.0 , 8.0 , 9.0 ],
3267
+ [10.0 , 11.0 , 12.0 ]]]
3268
+ diag_val = np .array (raw_diag ).astype (np .float32 )
3269
+ k_val = np .array ([2 , 3 ]).astype (np .int32 )
3270
+ row_val = np .array (- 1 ).astype (np .int32 )
3271
+ col_val = np .array (6 ).astype (np .int32 )
3272
+
3273
+ def func (diag , k , row , col ):
3274
+ return tf .raw_ops .MatrixDiagV3 (diagonal = diag , k = k , num_rows = row , num_cols = col ,
3275
+ padding_value = 0.456 , align = 'LEFT_LEFT' , name = _TFOUTPUT )
3276
+
3277
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : diag_val , _INPUT1 : k_val ,
3278
+ _INPUT2 : row_val , _INPUT3 : col_val })
3279
+
3280
+ @check_opset_min_version (12 )
3281
+ @check_tf_min_version ("2.2" )
3282
+ def test_matrix_diag_v3_single_dim_min_col (self ):
3283
+ raw_diag = [1.0 , 2.0 , 3.0 ]
3284
+ diag_val = np .array (raw_diag ).astype (np .float32 )
3285
+ k_val = np .array (- 1 ).astype (np .int32 )
3286
+ row_val = np .array (5 ).astype (np .int32 )
3287
+ col_val = np .array (- 1 ).astype (np .int32 )
3288
+
3289
+ def func (diag , k , row , col ):
3290
+ return tf .raw_ops .MatrixDiagV3 (diagonal = diag , k = k , num_rows = row , num_cols = col ,
3291
+ padding_value = 0.789 , align = 'LEFT_RIGHT' , name = _TFOUTPUT )
3292
+
3293
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : diag_val , _INPUT1 : k_val ,
3294
+ _INPUT2 : row_val , _INPUT3 : col_val })
3295
+
3296
+ @check_opset_min_version (12 )
3297
+ @check_tf_min_version ("2.2" )
3298
+ def test_matrix_diag_v3_2single_dim_row_col (self ):
3299
+ raw_diag = [[1 , 2 , 3 ], [4 , 5 , 6 ]]
3300
+ diag_val = np .array (raw_diag ).astype (np .int64 )
3301
+ k_val = np .array (0 ).astype (np .int32 )
3302
+ row_val = np .array (3 ).astype (np .int32 )
3303
+ col_val = np .array (4 ).astype (np .int32 )
3304
+
3305
+ def func (diag , k , row , col ):
3306
+ return tf .raw_ops .MatrixDiagV3 (diagonal = diag , k = k , num_rows = row , num_cols = col ,
3307
+ padding_value = 7 , align = 'LEFT_RIGHT' , name = _TFOUTPUT ), \
3308
+ tf .raw_ops .MatrixDiag (diagonal = diag , name = _TFOUTPUT1 )
3309
+
3310
+ self ._run_test_case (func , [_OUTPUT , _OUTPUT1 ],
3311
+ {_INPUT : diag_val , _INPUT1 : k_val ,
3312
+ _INPUT2 : row_val , _INPUT3 : col_val })
3313
+
3314
+ @check_opset_min_version (12 )
3315
+ @check_tf_min_version ("2.2" )
3316
+ def test_matrix_diag_v3_1single_dim_row_col (self ):
3317
+ raw_diag = [1 , 2 , 3 , 4 , 5 ]
3318
+ diag_val = np .array (raw_diag ).astype (np .int64 )
3319
+ k_val = np .array (0 ).astype (np .int32 )
3320
+ row_val = np .array (5 ).astype (np .int32 )
3321
+ col_val = np .array (10 ).astype (np .int32 )
3322
+
3323
+ def func (diag , k , row , col ):
3324
+ return tf .raw_ops .MatrixDiagV3 (diagonal = diag , k = k , num_rows = row , num_cols = col ,
3325
+ padding_value = 7 , align = 'LEFT_RIGHT' , name = _TFOUTPUT )
3326
+
3327
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : diag_val , _INPUT1 : k_val ,
3328
+ _INPUT2 : row_val , _INPUT3 : col_val })
3329
+
3330
+ @check_opset_min_version (12 )
3331
+ @check_tf_min_version ("2.2" )
3332
+ def test_matrix_set_diag_v3 (self ):
3333
+ input_val = np .array ([[[7 , 7 , 7 , 7 ],
3334
+ [7 , 7 , 7 , 7 ],
3335
+ [7 , 7 , 7 , 7 ]],
3336
+ [[7 , 7 , 7 , 7 ],
3337
+ [7 , 7 , 7 , 7 ],
3338
+ [7 , 7 , 7 , 7 ]]]).astype (np .int64 )
3339
+ diag_val = np .array ([[1 , 2 , 3 ],
3340
+ [4 , 5 , 6 ]]).astype (np .int64 )
3341
+ k_val = np .array ([0 ])
3342
+
3343
+ def func (base_matrix , diag , k ):
3344
+ return tf .raw_ops .MatrixSetDiagV3 (input = base_matrix , diagonal = diag , k = k , align = 'RIGHT_LEFT' , name = _TFOUTPUT )
3345
+
3346
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : input_val , _INPUT1 : diag_val , _INPUT2 : k_val })
3347
+
3238
3348
3239
3349
if __name__ == '__main__' :
3240
3350
unittest_main ()
0 commit comments