@@ -26,8 +26,7 @@ def any_version(cls, opset, ctx, node, **kwargs):
26
26
node .domain = constants .CONTRIB_OPS_DOMAIN
27
27
for a in list (node .attr .keys ()):
28
28
del node .attr [a ]
29
- unsqueeze_node = GraphBuilder (ctx ).make_squeeze (
30
- {'data' : node .input [1 ], 'axes' : [0 ]}, return_node = True )
29
+ unsqueeze_node = GraphBuilder (ctx ).make_unsqueeze ({'data' : node .input [1 ], 'axes' : [0 ]}, return_node = True )
31
30
32
31
skip_empty_const = ctx .make_const (utils .make_name ('skip_empty_const' ), np .array ([skip_empty ], np .bool ))
33
32
ctx .replace_inputs (node , [node .input [0 ], unsqueeze_node .output [0 ], skip_empty_const .output [0 ]])
@@ -88,8 +87,8 @@ def any_version(cls, opset, ctx, node, **kwargs):
88
87
if ctx .get_shape (inp ) == [] and shape_node is not None :
89
88
expand_node = ctx .make_node ("Expand" , [inp , shape_node .output [0 ]])
90
89
inp = expand_node .output [0 ]
91
- unsqueeze_node = GraphBuilder (ctx ).make_squeeze ({'data' : inp , 'axes' : [0 ]})
92
- unsqueezes .append (unsqueeze_node . output [ 0 ] )
90
+ unsqueeze_node = GraphBuilder (ctx ).make_unsqueeze ({'data' : inp , 'axes' : [0 ]})
91
+ unsqueezes .append (unsqueeze_node )
93
92
stack_node = ctx .make_node ("Concat" , unsqueezes , attr = {'axis' : 0 })
94
93
ctx .replace_inputs (node , [stack_node .output [0 ], separator_node .output [0 ], axis_node .output [0 ]])
95
94
0 commit comments