Skip to content

Commit f8becaa

Browse files
authored
Merge pull request #496 from chinhuang007/add-conv1d
add conv1d support
2 parents 7dff5e7 + 567f344 commit f8becaa

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

tests/test_backend.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2132,6 +2132,40 @@ def test_non_max_suppression(self):
21322132
_ = tf.identity(res2, name=_TFOUTPUT1)
21332133
self._run_test_case([_OUTPUT, _OUTPUT1], {_INPUT: boxes_val, _INPUT1: scores_val})
21342134

2135+
def _conv1d_test(self, x_val, w, stride=None, padding="VALID", rtol=1e-07):
2136+
if stride is None:
2137+
stride = 1
2138+
tf.reset_default_graph()
2139+
kernel = tf.constant(w, dtype=tf.float32, name='k')
2140+
x = tf.placeholder(tf.float32, shape=x_val.shape, name=_TFINPUT)
2141+
conv = tf.nn.conv1d(x, kernel, stride=stride, padding=padding)
2142+
_ = tf.identity(conv, name=_TFOUTPUT)
2143+
self._run_test_case([_OUTPUT], {_INPUT: x_val}, rtol=rtol)
2144+
2145+
def test_conv1d_1(self):
2146+
x_val = make_xval((1, 7, 1))
2147+
w = np.array([2., 1., 3.], dtype=np.float32).reshape(3, 1, 1)
2148+
self._conv1d_test(x_val, w)
2149+
2150+
def test_conv1d_2(self):
2151+
x_val = make_xval((1, 7, 1))
2152+
w = np.array([2., 1., 3.], dtype=np.float32).reshape(3, 1, 1)
2153+
self._conv1d_test(x_val, w, stride=2)
2154+
2155+
def test_conv1d_3(self):
2156+
x_val = make_xval((1, 7, 1))
2157+
w = np.array([2., 1., 3.], dtype=np.float32).reshape(3, 1, 1)
2158+
self._conv1d_test(x_val, w, padding="SAME")
2159+
2160+
def test_conv1d_4(self):
2161+
x_val = make_xval((1, 7, 1))
2162+
w = np.array([2., 1., 3.], dtype=np.float32).reshape(3, 1, 1)
2163+
self._conv1d_test(x_val, w, rtol=1e-05)
2164+
2165+
def test_conv1d_5(self):
2166+
x_val = make_xval((1, 7, 1))
2167+
w = np.array([3., 3., 3.], dtype=np.float32).reshape(3, 1, 1)
2168+
self._conv1d_test(x_val, w)
21352169

21362170
if __name__ == '__main__':
21372171
unittest_main()

tf2onnx/onnx_opset/nn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def conv_kernel_shape(ctx, node, input_idx, spatial=2):
197197
return kernel_shape
198198

199199

200-
@tf_op(["Conv2D", "Conv3D"])
200+
@tf_op(["Conv1D", "Conv2D", "Conv3D"])
201201
class ConvOp:
202202
@classmethod
203203
def version_4(cls, ctx, node, **kwargs):

0 commit comments

Comments
 (0)