@@ -226,21 +226,43 @@ def version_1(cls, ctx, node, **kwargs):
         # Note: inputs are reversed from what one would expect.
         kernel_shape = conv_kernel_shape(ctx, node, 1)
         input_shape = ctx.get_shape(node.input[2])
+        append_slice = False
 
         # output_shape is explicitly specified here, in this case pads values are auto generated/calculated.
-        output_shape = ctx.get_shape(node.output[0])
-        if node.is_nhwc():
-            new_output_shape = [output_shape[1], output_shape[2]]
-            input_hw = [input_shape[1], input_shape[2]]
+        if node.inputs[0].is_const():
+            output_shape = ctx.get_shape(node.output[0])
+            if node.is_nhwc():
+                new_output_shape = [output_shape[1], output_shape[2]]
+                input_hw = [input_shape[1], input_shape[2]]
+            else:
+                new_output_shape = [output_shape[2], output_shape[3]]
+                input_hw = [input_shape[2], input_shape[3]]
+            utils.make_sure(new_output_shape.count(-1) <= 0, "output h and w need to be known")
+            utils.make_sure(new_output_shape[0] >= input_hw[0] and new_output_shape[1] >= input_hw[1],
+                            "output h and w cannot be smaller than input h and w.")
+            node.set_attr("output_shape", new_output_shape)
         else:
-            new_output_shape = [output_shape[2], output_shape[3]]
-            input_hw = [input_shape[2], input_shape[3]]
-
-        utils.make_sure(new_output_shape.count(-1) <= 0, "output h and w need to be known")
-        utils.make_sure(new_output_shape[0] >= input_hw[0] and new_output_shape[1] >= input_hw[1],
-                        "output h and w cannot be smaller than input h and w.")
-
-        node.set_attr("output_shape", new_output_shape)
+            input_shape = ctx.make_node("Cast", [node.input[0]], attr={'to': TensorProto.INT64})
+            output_shape = ctx.make_node("Shape", [node.output[0]])
+            output_h = GraphBuilder(ctx).make_slice({"data": output_shape.output[0], "ends": [2], "starts": [1], "axes": [0]})
+            output_w = GraphBuilder(ctx).make_slice({"data": output_shape.output[0], "ends": [3], "starts": [2], "axes": [0]})
+            expect_h = GraphBuilder(ctx).make_slice({"data": input_shape.output[0], "ends": [2], "starts": [1], "axes": [0]})
+            expect_w = GraphBuilder(ctx).make_slice({"data": input_shape.output[0], "ends": [3], "starts": [2], "axes": [0]})
+            diff_h = ctx.make_node("Sub", [output_h, expect_h])
+            diff_w = ctx.make_node("Sub", [output_w, expect_w])
+            const_two = ctx.make_const(utils.make_name(node.name + "_const_two"), np.array([2], dtype=np.int64))
+            start_h = ctx.make_node("Div", [diff_h.output[0], const_two.output[0]])
+            start_w = ctx.make_node("Div", [diff_w.output[0], const_two.output[0]])
+            end_h = ctx.make_node("Add", [start_h.output[0], expect_h])
+            end_w = ctx.make_node("Add", [start_w.output[0], expect_w])
+            starts = ctx.make_node("Concat", [start_h.output[0], start_w.output[0]], attr={"axis": 0})
+            ends = ctx.make_node("Concat", [end_h.output[0], end_w.output[0]], attr={"axis": 0}, name="concat_efgh")
+            const_one_two = ctx.make_const(utils.make_name(node.name + "_const_one_two"), np.array([1, 2], dtype=np.int64))
+            slice_node = ctx.make_node("Slice", [node.output[0], starts.output[0], ends.output[0], const_one_two.output[0]])
+            downstream_nodes = ctx.find_output_consumers(node.output[0])
+            downstream_nodes.remove(output_shape)
+            downstream_nodes.remove(slice_node)
+            ctx.replace_all_inputs(downstream_nodes, node.output[0], slice_node.output[0])
 
         strides = conv_dims_attr(node, "strides")
         conv_dims_attr(node, "dilations")
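
For context: in TF's Conv2DBackpropInput the first input carries the requested output spatial sizes (hence the "inputs are reversed" note and the Cast to INT64 above), so when that input is not constant the new else branch builds a Shape/Slice/Sub/Div/Add/Concat/Slice subgraph that center-crops the ConvTranspose result down to the requested H and W. A minimal numpy sketch of the same arithmetic, assuming NHWC layout; center_crop_hw is a hypothetical helper for illustration, not part of tf2onnx:

import numpy as np

def center_crop_hw(result, expect_h, expect_w):
    # Mirrors the Sub/Div/Add chain in the diff:
    # start = (actual - expected) // 2, end = start + expected,
    # sliced along axes 1 and 2 (H and W in NHWC).
    actual_h, actual_w = result.shape[1], result.shape[2]
    start_h = (actual_h - expect_h) // 2
    start_w = (actual_w - expect_w) // 2
    return result[:, start_h:start_h + expect_h, start_w:start_w + expect_w, :]

# Example: the ConvTranspose produced 9x9 spatial dims but 8x8 was requested.
out = np.zeros((1, 9, 9, 3), dtype=np.float32)
print(center_crop_hw(out, 8, 8).shape)  # (1, 8, 8, 3)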