@@ -248,6 +248,7 @@ def version_1(cls, ctx, node, **kwargs):
248
248
# Note: inputs are reversed from what one would expect.
249
249
conv_kernel_shape (ctx , node , 1 )
250
250
input_shape = ctx .get_shape (node .input [2 ])
251
+ output_shape_orig = node .output_shapes
251
252
252
253
# ouput_shape is explicitly specified here, in this case pads values are auto generated/calculated.
253
254
if node .inputs [0 ].is_const ():
@@ -285,7 +286,8 @@ def version_1(cls, ctx, node, **kwargs):
285
286
const_one_two = ctx .make_const (utils .make_name (node .name + "_const_one_two" ),
286
287
np .array ([1 , 2 ], dtype = np .int64 ))
287
288
slice_node = ctx .make_node ("Slice" ,
288
- [node .output [0 ], starts .output [0 ], ends .output [0 ], const_one_two .output [0 ]])
289
+ [node .output [0 ], starts .output [0 ], ends .output [0 ], const_one_two .output [0 ]],
290
+ shapes = output_shape_orig )
289
291
downstream_nodes = ctx .find_output_consumers (node .output [0 ])
290
292
downstream_nodes .remove (output_shape )
291
293
downstream_nodes .remove (slice_node )
0 commit comments