Skip to content

Commit e712d65

Browse files
committed
trim input
1 parent ef5522d commit e712d65

File tree

2 files changed

+56
-13
lines changed

2 files changed

+56
-13
lines changed

tests/test_backend.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838
_INPUT2 = "input2:0"
3939
_TFINPUT3 = "input3"
4040
_INPUT3 = "input3:0"
41+
_TFINPUT4 = "input4"
42+
_INPUT4 = "input4:0"
4143
_TFOUTPUT = "output"
4244
_OUTPUT = "output:0"
4345
_TFOUTPUT1 = "output1"
@@ -109,7 +111,7 @@ def _run_test_case(self, output_names_with_port, feed_dict, **kwargs):
109111
kwargs["convert_var_to_const"] = False
110112
kwargs["constant_fold"] = False
111113
return self.run_test_case(feed_dict, [], output_names_with_port, **kwargs)
112-
114+
'''
113115
def _test_expand_dims_known_rank(self, idx):
114116
tf.reset_default_graph()
115117
x_val = make_xval([3, 4])
@@ -2900,6 +2902,25 @@ def test_unique(self):
29002902
# FIXME: indices in onnx are not the same as in tensorflow so don't check for now
29012903
# self._run_test_case([_OUTPUT, _OUTPUT1], {_INPUT: x_val})
29022904
self._run_test_case([_OUTPUT], {_INPUT: x_val})
2905+
'''
2906+
2907+
def test_Conv2DBackpropInput_const(self):
2908+
input_sizes_val = np.array([1, 10, 10, 3], dtype=np.int32)
2909+
filter_val = np.random.randint(low=0, high=256, size=[3, 3, 3, 5]).astype(np.float32)
2910+
out_backprop_val = np.random.randint(low=0, high=256, size=[1, 10, 10, 5]).astype(np.float32)
2911+
_ = tf.nn.conv2d_backprop_input(input_sizes=input_sizes_val, filter=filter_val, out_backprop=out_backprop_val, strides=[1,1,1,1], padding='SAME', name=_TFOUTPUT)
2912+
self._run_test_case([_OUTPUT], {})
2913+
2914+
def test_Conv2DBackpropInput(self):
2915+
input_sizes_val = np.array([1, 10, 10, 3], dtype=np.int32)
2916+
input_sizes = tf.placeholder(tf.int32, input_sizes_val.shape, name=_TFINPUT)
2917+
filter_val = np.random.randint(low=0, high=256, size=[3, 3, 3, 5]).astype(np.float32)
2918+
filter = tf.placeholder(tf.float32, filter_val.shape, name=_TFINPUT1)
2919+
out_backprop_val = np.random.randint(low=0, high=256, size=[1, 10, 10, 5]).astype(np.float32)
2920+
out_backprop = tf.placeholder(tf.float32, out_backprop_val.shape, name=_TFINPUT2)
2921+
_ = tf.nn.conv2d_backprop_input(input_sizes, filter, out_backprop, strides=[1,1,1,1], padding='SAME', name=_TFOUTPUT)
2922+
self._run_test_case([_OUTPUT], {_INPUT: input_sizes_val, _INPUT1: filter_val, _INPUT2: out_backprop_val})
2923+
29032924

29042925
if __name__ == '__main__':
29052926
unittest_main()

tf2onnx/onnx_opset/nn.py

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -226,21 +226,43 @@ def version_1(cls, ctx, node, **kwargs):
226226
# Note: inputs are reversed from what one would expect.
227227
kernel_shape = conv_kernel_shape(ctx, node, 1)
228228
input_shape = ctx.get_shape(node.input[2])
229+
append_slice = False
229230

230231
# ouput_shape is explicitly specified here, in this case pads values are auto generated/calculated.
231-
output_shape = ctx.get_shape(node.output[0])
232-
if node.is_nhwc():
233-
new_output_shape = [output_shape[1], output_shape[2]]
234-
input_hw = [input_shape[1], input_shape[2]]
232+
if node.inputs[0].is_const():
233+
output_shape = ctx.get_shape(node.output[0])
234+
if node.is_nhwc():
235+
new_output_shape = [output_shape[1], output_shape[2]]
236+
input_hw = [input_shape[1], input_shape[2]]
237+
else:
238+
new_output_shape = [output_shape[2], output_shape[3]]
239+
input_hw = [input_shape[2], input_shape[3]]
240+
utils.make_sure(new_output_shape.count(-1) <= 0, "output h and w need to be known")
241+
utils.make_sure(new_output_shape[0] >= input_hw[0] and new_output_shape[1] >= input_hw[1],
242+
"output h and w cannot be smaller than input h and w.")
243+
node.set_attr("output_shape", new_output_shape)
235244
else:
236-
new_output_shape = [output_shape[2], output_shape[3]]
237-
input_hw = [input_shape[2], input_shape[3]]
238-
239-
utils.make_sure(new_output_shape.count(-1) <= 0, "output h and w need to be known")
240-
utils.make_sure(new_output_shape[0] >= input_hw[0] and new_output_shape[1] >= input_hw[1],
241-
"output h and w cannot be smaller than input h and w.")
242-
243-
node.set_attr("output_shape", new_output_shape)
245+
input_shape = ctx.make_node("Cast", [node.input[0]], attr={'to': TensorProto.INT64})
246+
output_shape = ctx.make_node("Shape", [node.output[0]])
247+
output_h = GraphBuilder(ctx).make_slice({"data": output_shape.output[0], "ends": [2], "starts": [1], "axes": [0]})
248+
output_w = GraphBuilder(ctx).make_slice({"data": output_shape.output[0], "ends": [3], "starts": [2], "axes": [0]})
249+
expect_h = GraphBuilder(ctx).make_slice({"data": input_shape.output[0], "ends": [2], "starts": [1], "axes": [0]})
250+
expect_w = GraphBuilder(ctx).make_slice({"data": input_shape.output[0], "ends": [3], "starts": [2], "axes": [0]})
251+
diff_h = ctx.make_node("Sub", [output_h, expect_h])
252+
diff_w = ctx.make_node("Sub", [output_w, expect_w])
253+
const_two = ctx.make_const(utils.make_name(node.name + "_const_two"), np.array([2], dtype=np.int64))
254+
start_h = ctx.make_node("Div", [diff_h.output[0], const_two.output[0]])
255+
start_w = ctx.make_node("Div", [diff_w.output[0], const_two.output[0]])
256+
end_h = ctx.make_node("Add", [start_h.output[0], expect_h])
257+
end_w = ctx.make_node("Add", [start_w.output[0], expect_w])
258+
starts = ctx.make_node("Concat", [start_h.output[0], start_w.output[0]], attr={"axis":0})
259+
ends = ctx.make_node("Concat", [end_h.output[0], end_w.output[0]], attr={"axis":0}, name="concat_efgh")
260+
const_one_two = ctx.make_const(utils.make_name(node.name + "_const_one_two"), np.array([1,2], dtype=np.int64))
261+
slice_node = ctx.make_node("Slice", [node.output[0], starts.output[0], ends.output[0], const_one_two.output[0]])
262+
downstream_nodes = ctx.find_output_consumers(node.output[0])
263+
downstream_nodes.remove(output_shape)
264+
downstream_nodes.remove(slice_node)
265+
ctx.replace_all_inputs(downstream_nodes, node.output[0], slice_node.output[0])
244266

245267
strides = conv_dims_attr(node, "strides")
246268
conv_dims_attr(node, "dilations")

0 commit comments

Comments
 (0)