@@ -1256,8 +1256,8 @@ def any_version_after9(cls, opset, ctx, node, **kwargs):
1256
1256
depth = GraphBuilder (ctx ).make_unsqueeze ({'data' : node .input [1 ], 'axes' : [0 ]})
1257
1257
on_value = node .input [2 ]
1258
1258
off_value = node .input [3 ]
1259
- on_value = ctx . make_node ( "Unsqueeze" , [ on_value ], attr = { " axes" : [0 ]}). output [ 0 ]
1260
- off_value = ctx . make_node ( "Unsqueeze" , [ off_value ], attr = { " axes" : [0 ]}). output [ 0 ]
1259
+ on_value = GraphBuilder ( ctx ). make_unsqueeze ({ 'data' : on_value , ' axes' : [0 ]})
1260
+ off_value = GraphBuilder ( ctx ). make_unsqueeze ({ 'data' : off_value , ' axes' : [0 ]})
1261
1261
off_on_value = ctx .make_node ("Concat" , [off_value , on_value ], attr = {"axis" : 0 }).output [0 ]
1262
1262
1263
1263
indices = node .input [0 ]
@@ -2385,15 +2385,17 @@ def normalize():
2385
2385
pad_length_2 = body_graph .make_node ('Concat' , [zeo , pad_length .output [0 ]], attr = {'axis' : 0 })
2386
2386
padded_range = body_graph .make_node ('Pad' , [sliced_range .output [0 ], pad_length_2 .output [0 ]])
2387
2387
# opset == 11, no need to change unsqueeze
2388
- unsqueezed_range = body_graph .make_node ('Unsqueeze' , [padded_range .output [0 ]], attr = {'axes' : [1 ]})
2388
+ unsqueezed_range = GraphBuilder (body_graph ).make_unsqueeze (
2389
+ {'data' : padded_range .output [0 ], 'axes' : [1 ]}, return_node = True )
2389
2390
half_shape_x = body_graph .make_node ('Slice' ,
2390
2391
[new_shape .output [0 ], zeo , minus_two ])
2391
2392
shape_range = body_graph .make_node ('Shape' , [unsqueezed_range .output [0 ]])
2392
2393
full_shape = body_graph .make_node ('Concat' , [half_shape_x .output [0 ], shape_range .output [0 ]], attr = {'axis' : 0 })
2393
2394
expanded_range = body_graph .make_node ('Expand' , [unsqueezed_range .output [0 ], full_shape .output [0 ]])
2394
2395
gathered_input = body_graph .make_node ('GatherElements' , [processed_input .output [0 ], expanded_range .output [0 ]],
2395
2396
attr = {'axis' : - 1 })
2396
- squeezed_input = body_graph .make_node ('Squeeze' , [gathered_input .output [0 ]], attr = {'axes' : [- 1 ]})
2397
+ squeezed_input = GraphBuilder (body_graph ).make_squeeze (
2398
+ {'data' : gathered_input .output [0 ], 'axes' : [- 1 ]}, return_node = True )
2397
2399
left_width = body_graph .make_node ('Sub' , [new_width .output [0 ], abs_k .output [0 ]])
2398
2400
dims = body_graph .make_node ('Concat' , [left_width .output [0 ], new_depth .output [0 ]], attr = {'axis' : 0 })
2399
2401
valid_dim = body_graph .make_node ('ReduceMin' , [dims .output [0 ]])
@@ -2505,8 +2507,8 @@ def normalize():
2505
2507
raw_output_shape + [- 1 ])
2506
2508
squeeze_sliced_graph = ctx .create_new_graph_with_same_config ()
2507
2509
squeeze_sliced_graph .parent_graph = ctx
2508
- squeeze_sliced = squeeze_sliced_graph . make_node ( 'Squeeze' , [ final_output_right_sliced . output [ 0 ]],
2509
- attr = { 'axes' : [- 2 ]})
2510
+ squeeze_sliced = GraphBuilder ( squeeze_sliced_graph ). make_squeeze (
2511
+ { 'data' : final_output_right_sliced . output [ 0 ], 'axes' : [- 2 ]}, return_node = True )
2510
2512
squeeze_sliced_graph .add_graph_output (squeeze_sliced .output [0 ], ctx .get_dtype (node .input [0 ]), raw_output_shape )
2511
2513
shapes = node .output_shapes
2512
2514
dtypes = node .output_dtypes
@@ -2680,14 +2682,14 @@ def version_13(cls, ctx, node, **kwargs):
2680
2682
@tf_op (["MatrixDiag" , "MatrixDiagV2" , "MatrixDiagV3" ])
2681
2683
class MatrixDiag :
2682
2684
@classmethod
2683
- def any_version (cls , opset , ctx , node , ** kwargs ):
2685
+ def version_12 (cls , ctx , node , ** kwargs ):
2684
2686
# Assemble MatrixDiagV3 by ReverseSequence
2685
2687
argc = len (node .input )
2686
2688
2687
- if opset >= 13 :
2688
- squeeze_axes0 = ctx .make_const (utils .make_name ("const_axes" ), np .array ([0 ], dtype = np .int64 ))
2689
- squeeze_axes_1 = ctx .make_const (utils .make_name ("const_axes" ), np .array ([- 1 ], dtype = np .int64 ))
2690
- squeeze_axes_2 = ctx .make_const (utils .make_name ("const_axes" ), np .array ([- 2 ], dtype = np .int64 ))
2689
+ if ctx . opset >= 13 :
2690
+ squeeze_axes0 = ctx .make_const (utils .make_name ("const_axes" ), np .array ([0 ], dtype = np .int64 )). output [ 0 ]
2691
+ squeeze_axes_1 = ctx .make_const (utils .make_name ("const_axes" ), np .array ([- 1 ], dtype = np .int64 )). output [ 0 ]
2692
+ squeeze_axes_2 = ctx .make_const (utils .make_name ("const_axes" ), np .array ([- 2 ], dtype = np .int64 )). output [ 0 ]
2691
2693
2692
2694
minus_two , minus_one , zeo , one , two = \
2693
2695
[n .output [0 ] for n in ctx .make_consts ([[- 2 ], [- 1 ], [0 ], [1 ], [2 ]])]
@@ -2712,7 +2714,7 @@ def processdiag():
2712
2714
diag = node .input [0 ]
2713
2715
shape = ctx .get_shape (diag )
2714
2716
if len (shape ) == 1 :
2715
- if opset < 13 :
2717
+ if ctx . opset < 13 :
2716
2718
diag = mknode ("Unsqueeze" , [diag ], attr = {"axes" : [0 ]})
2717
2719
else :
2718
2720
diag = mknode ("Unsqueeze" , [diag , squeeze_axes0 ])
@@ -2737,7 +2739,7 @@ def id_diag():
2737
2739
def ex_diag ():
2738
2740
g = ctx .create_new_graph_with_same_config ()
2739
2741
g .parent_graph = ctx
2740
- if opset < 13 :
2742
+ if ctx . opset < 13 :
2741
2743
ex = mknode2 (g , "Unsqueeze" , [diag ], attr = {"axes" : [- 2 ]})
2742
2744
else :
2743
2745
ex = mknode2 (g , "Unsqueeze" , [diag , squeeze_axes_2 ])
@@ -2755,7 +2757,7 @@ def squeeze_12(name):
2755
2757
def squeeze_13 (name ):
2756
2758
return ctx .make_node ("Squeeze" , [name , squeeze_axes_1 ]).output [0 ]
2757
2759
2758
- squeeze = squeeze_12 if opset < 13 else squeeze_13
2760
+ squeeze = squeeze_12 if ctx . opset < 13 else squeeze_13
2759
2761
2760
2762
# gather inputs
2761
2763
diag , k , k_min , k_max , k_max_nxt = processdiag ()
@@ -3018,14 +3020,10 @@ def paddiag():
3018
3020
ctx .make_node ("Identity" , [padded ], name = node .name ,
3019
3021
outputs = node .output , shapes = shapes , dtypes = dtypes )
3020
3022
3021
- @classmethod
3022
- def version_12 (cls , ctx , node , ** kwargs ):
3023
- cls .any_version (12 , ctx , node , ** kwargs )
3024
-
3025
3023
@classmethod
3026
3024
def version_13 (cls , ctx , node , ** kwargs ):
3027
3025
# Parameters moved to inputs for operator Squeeze, Unsqueeze.
3028
- cls .any_version ( 13 , ctx , node , ** kwargs )
3026
+ cls .version_12 ( ctx , node , ** kwargs )
3029
3027
3030
3028
3031
3029
@tf_op ("MatrixSetDiagV3" )
0 commit comments