Commit 339e655

refine and add seqconv elementwiseadd relu op test
1 parent e5ce965 commit 339e655

File tree: 3 files changed, +164 −69 lines

paddle/fluid/operators/fusion_seqconv_eltadd_relu_op.cc

Lines changed: 21 additions & 19 deletions
@@ -40,17 +40,19 @@ void FusionSeqConvEltAddReluOp::InferShape(
 
   auto x_dims = ctx->GetInputDim("X");
   auto w_dims = ctx->GetInputDim("Filter");
+  int context_length = ctx->Attrs().Get<int>("contextLength");
   PADDLE_ENFORCE(
       ctx->Attrs().Get<int>("contextStride") == 1,
       "Currently, FusionSeqConvEltAddReluOp only supports contextStride=1.");
   PADDLE_ENFORCE(x_dims.size() == 2 && w_dims.size() == 2,
                  "Input(X, Filter) should be 2-D tensor.");
   PADDLE_ENFORCE(x_dims.size() == 2 && w_dims.size() == 2,
                  "Input(X, Filter) should be 2-D tensor.");
-  PADDLE_ENFORCE(
-      w_dims[0] == ctx->Attrs().Get<int>("contextLength") * x_dims[1],
-      "Filter's height should be context_length * "
-      "input_hidden_size .");
+  PADDLE_ENFORCE(w_dims[0] == context_length * x_dims[1],
+                 "Filter's height should be context_length * "
+                 "input_hidden_size .");
+  PADDLE_ENFORCE_GT(context_length + ctx->Attrs().Get<int>("contextStart"), 0,
+                    "contextStart size should be smaller than contextLength.");
 
   ctx->SetOutputDim("Out", {x_dims[0], w_dims[1]});
   ctx->SetOutputDim("ColMat", {x_dims[0], w_dims[0]});
@@ -156,9 +158,8 @@ class FusionSeqConvEltAddReluKernel : public framework::OpKernel<T> {
       T* dst_data = col_data + st * col_mat_w;
       int seq_len = ed - st;
       if (seq_len > up_pad + down_pad) {
-        // zero all up_pad
+        // zero all up_pad and fill data
         std::memset(dst_data, 0, up_pad * col_mat_w_sz);
-        // fill up_pad data
         dst_data = dst_data + up_pad * src_mat_w;
         int copy_size = col_mat_w_sz - up_pad * src_mat_w_sz;
         for (int j = 0; j < up_pad; ++j) {
@@ -173,9 +174,8 @@ class FusionSeqConvEltAddReluKernel : public framework::OpKernel<T> {
           dst_data += col_mat_w;
           src_data += src_mat_w;
         }
-        // zero all down_pad
+        // zero all down_pad and fill data
         std::memset(dst_data, 0, down_pad * col_mat_w_sz);
-        // fill down_pad data
         copy_size -= src_mat_w_sz;
         for (int j = 0; j < down_pad; ++j) {
           std::memcpy(dst_data, src_data, copy_size);
@@ -186,27 +186,29 @@ class FusionSeqConvEltAddReluKernel : public framework::OpKernel<T> {
       } else {
         PADDLE_ENFORCE_GE(context_length, up_pad + down_pad + 1);
         std::memset(dst_data, 0, seq_len * col_mat_w_sz);
+        dst_data = dst_data + up_pad * src_mat_w;
         int zero_sz = up_pad * src_mat_w_sz;
-        int seq_len_size = seq_len * src_mat_w_sz;
+        int cur_src_sz = seq_len * src_mat_w_sz;
         for (int j = 0; j < std::min(up_pad, seq_len); ++j) {
-          int copy_size = std::min(seq_len_size, col_mat_w_sz - zero_sz);
-          std::memcpy(dst_data + zero_sz / sizeof(T), src_data, copy_size);
-          dst_data += col_mat_w;
+          int copy_size = std::min(cur_src_sz, col_mat_w_sz - zero_sz);
+          std::memcpy(dst_data, src_data, copy_size);
+          dst_data += (col_mat_w - src_mat_w);
           zero_sz -= src_mat_w_sz;
         }
+        // from bottom
+        dst_data = col_data + ed * col_mat_w;
+        src_data = x_data + st * src_mat_w;
         zero_sz = down_pad * src_mat_w_sz;
-        dst_data = col_data + (ed - 1) * col_mat_w;
-        src_data = x_data + (ed - up_pad - 1) * src_mat_w;
-        for (int j = 0; j < std::min(0, seq_len - up_pad); ++j) {
-          int copy_size = std::min(seq_len_size, col_mat_w_sz - zero_sz);
-          std::memcpy(dst_data, src_data, copy_size);
+        for (int j = 1; j <= std::min(down_pad, seq_len); ++j) {
+          int copy_size = std::min(cur_src_sz, col_mat_w_sz - zero_sz);
+          std::memcpy(dst_data - (zero_sz + copy_size) / sizeof(T),
+                      src_data + std::max(seq_len - j - up_pad, 0) * src_mat_w,
+                      copy_size);
           dst_data -= col_mat_w;
-          src_data += src_mat_w;
           zero_sz -= src_mat_w_sz;
         }
       }
     }
-
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
     math::FCCompute<DeviceContext, T>(blas, x_dims[0], w_dims[1], w_dims[0],
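For context on what the rewritten branch builds: per sequence, the kernel assembles an im2col-style ColMat of shape [seq_len, contextLength * hidden], and the fused op then produces Out = relu(ColMat x Filter + Bias), which is what the new unit test checks against. A rough numpy sketch of that layout (not the kernel code; seq_im2col is a hypothetical helper, and up_pad/down_pad are assumed to be max(0, -contextStart) and max(0, contextStart + contextLength - 1), the usual sequence-conv padding):

import numpy as np

def seq_im2col(x, context_length, context_start):
    # x: [seq_len, hidden] -> col: [seq_len, context_length * hidden]
    seq_len, hidden = x.shape
    col = np.zeros((seq_len, context_length * hidden), dtype=x.dtype)
    for i in range(seq_len):
        for j in range(context_length):
            src = i + context_start + j
            if 0 <= src < seq_len:
                # in-range rows are copied; out-of-range rows stay zero (the padding
                # the kernel writes with std::memset)
                col[i, j * hidden:(j + 1) * hidden] = x[src]
    return col

# seq_len = 3 with context_start = -2, context_length = 4 gives up_pad = 2 and
# down_pad = 1, so seq_len <= up_pad + down_pad and the short-sequence branch
# rewritten in this commit is the one that must produce this layout.
x = np.arange(6, dtype=np.float32).reshape(3, 2)
col = seq_im2col(x, context_length=4, context_start=-2)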
Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import random
+from op_test import OpTest
+from test_seq_conv import seqconv
+
+
+class TestSeqConvEltAddRelu(OpTest):
+    def set_conf(self):
+        pass
+
+    def setUp(self):
+        self.op_type = 'fusion_seqconv_eltadd_relu'
+        self.lod = [[6, 4]]
+        self.in_fea_size = 16
+        self.out_fea_size = 8
+        self.context_length = 4
+        self.context_stride = 1
+        self.context_start = 0
+        self.set_conf()
+
+        assert self.context_stride == 1
+
+        T = sum(self.lod[0])
+        x = np.random.uniform(-1, 1, [T, self.in_fea_size]).astype('float32')
+        w = np.random.uniform(
+            -1, 1, [self.in_fea_size * self.context_length,
+                    self.out_fea_size]).astype('float32')
+        b = np.random.uniform(-2, 1, [1, self.out_fea_size]).astype('float32')
+        out = seqconv(x, self.lod, w, self.context_length, self.context_start)
+        out = np.maximum(out + b, 0)
+
+        self.inputs = {'X': (x, self.lod), 'Filter': w, 'Bias': b}
+        self.attrs = {
+            'contextStart': self.context_start,
+            'contextLength': self.context_length,
+            'contextStride': self.context_stride
+        }
+        self.outputs = {'Out': out}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestSeqConvEltAddReluBS1(TestSeqConvEltAddRelu):
+    def set_conf(self):
+        self.lod = [[10]]
+
+
+class TestSeqConvEltAddReluBS1Case2(TestSeqConvEltAddRelu):
+    def set_conf(self):
+        self.lod = [[2]]
+
+
+class TestSeqConvEltAddReluCase1(TestSeqConvEltAddRelu):
+    def set_conf(self):
+        self.lod = [[3, 5, 1, 6]]
+        self.context_length = 3
+        self.context_start = -2
+
+
+class TestSeqConvEltAddReluCase2(TestSeqConvEltAddRelu):
+    def set_conf(self):
+        self.lod = [[10, 1, 2, 4, 1, 5, 6]]
+        self.in_fea_size = 2
+        self.context_length = 4
+        self.context_start = -1
+
+
+class TestSeqConvEltAddReluCase3(TestSeqConvEltAddRelu):
+    def set_conf(self):
+        self.lod = [[10, 1, 2, 4, 1, 5, 6]]
+        self.context_length = 5
+        self.context_start = -4
+
+
+if __name__ == '__main__':
+    unittest.main()

python/paddle/fluid/tests/unittests/test_seq_conv.py

Lines changed: 49 additions & 50 deletions
@@ -20,6 +20,53 @@
 from op_test import OpTest
 
 
+def seqconv(x,
+            lod,
+            filter,
+            context_length,
+            context_start,
+            padding_trainable=False,
+            padding_data=None):
+    [T, M] = x.shape
+    col = np.zeros((T, context_length * M)).astype('float32')
+    offset = [0]
+    for seq_len in lod[0]:
+        offset.append(offset[-1] + seq_len)
+    begin_pad = np.max([0, -context_start])
+    for i in range(len(offset) - 1):
+        for j in range(context_length):
+            in_begin = offset[i] + context_start + j
+            in_end = offset[i + 1] + context_start + j
+            out_begin = offset[i]
+            out_end = offset[i + 1]
+            if in_begin < offset[i]:
+                pad_size = np.min(
+                    [offset[i] - in_begin, offset[i + 1] - offset[i]])
+                if padding_trainable:
+                    sub_w = padding_data[j:j + pad_size, :]
+                    col[offset[i]:offset[i] + pad_size, j * M:(j + 1) *
+                        M] = sub_w
+                out_begin = offset[i] + pad_size
+                in_begin = offset[i]
+
+            if in_end > offset[i + 1]:
+                pad_size = np.min(
+                    [in_end - offset[i + 1], offset[i + 1] - offset[i]])
+                if padding_trainable:
+                    sub_w = padding_data[begin_pad + context_start + j -
+                                         pad_size:begin_pad + context_start +
+                                         j, :]
+                    col[offset[i + 1] - pad_size:offset[i + 1], j * M:(j + 1) *
+                        M] = sub_w
+                in_end = offset[i + 1]
+                out_end = offset[i + 1] - pad_size
+            if in_end <= in_begin:
+                continue
+            in_sub = x[in_begin:in_end, :]
+            col[out_begin:out_end, j * M:(j + 1) * M] += in_sub
+    return np.dot(col, filter)
+
+
 class TestSeqProject(OpTest):
     def setUp(self):
         self.init_test_case()
@@ -66,57 +113,9 @@ def setUp(self):
             'paddingTrainable': self.padding_trainable,
             'contextStride': self.context_stride
         }
-        out = np.zeros(
-            (self.input_size[0], self.output_represention)).astype('float32')
+        out = seqconv(x, self.lod, w, self.context_length, self.context_start,
+                      self.padding_trainable, self.pad_data)
         self.outputs = {'Out': out}
-        self.compute()
-
-    def compute(self):
-        x, lod = self.inputs['X']
-        filter = self.inputs['Filter']
-        pading_data = self.pad_data
-        out = np.zeros((self.input_size[0], self.context_length *
-                        self.input_size[1])).astype('float32')
-        offset = [0]
-        for seq_len in lod[0]:
-            offset.append(offset[-1] + seq_len)
-        begin_pad = np.max([0, -self.context_start])
-
-        for i in range(len(offset) - 1):
-            for j in range(self.context_length):
-                in_begin = offset[i] + self.context_start + j
-                in_end = offset[i + 1] + self.context_start + j
-                out_begin = offset[i]
-                out_end = offset[i + 1]
-                if in_begin < offset[i]:
-                    pad_size = np.min(
-                        [offset[i] - in_begin, offset[i + 1] - offset[i]])
-                    if self.padding_trainable:
-                        sub_w = pading_data[j:j + pad_size, :]
-                        out[offset[i]:offset[i] + pad_size, j * self.input_size[
-                            1]:(j + 1) * self.input_size[1]] = sub_w
-                    out_begin = offset[i] + pad_size
-                    in_begin = offset[i]
-
-                if in_end > offset[i + 1]:
-                    pad_size = np.min(
-                        [in_end - offset[i + 1], offset[i + 1] - offset[i]])
-                    if self.padding_trainable:
-                        sub_w = pading_data[begin_pad + self.context_start + j -
-                                            pad_size:begin_pad +
-                                            self.context_start + j, :]
-                        out[offset[i + 1] - pad_size:offset[i + 1], j * self.
-                            input_size[1]:(j + 1) * self.input_size[1]] = sub_w
-                    in_end = offset[i + 1]
-                    out_end = offset[i + 1] - pad_size
-                if in_end <= in_begin:
-                    continue
-
-                in_sub = x[in_begin:in_end, :]
-                out[out_begin:out_end, j * self.input_size[1]:(j + 1) *
-                    self.input_size[1]] += in_sub
-
-        np.dot(out, filter, out=self.outputs['Out'])
 
     def test_check_output(self):
         self.check_output()
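With the reference computation lifted out of TestSeqProject into the module-level seqconv helper, other tests can import and reuse it directly; a minimal sketch of that reuse (shapes chosen to mirror the defaults of the new fusion test above, not an additional test from this commit):

import numpy as np
from test_seq_conv import seqconv

lod = [[6, 4]]  # two sequences, lengths 6 and 4
x = np.random.uniform(-1, 1, [sum(lod[0]), 16]).astype('float32')
w = np.random.uniform(-1, 1, [16 * 4, 8]).astype('float32')  # [in_fea * context_length, out_fea]
b = np.random.uniform(-2, 1, [1, 8]).astype('float32')
# seqconv -> elementwise add of Bias -> relu: the expected Out of the fused op
out = np.maximum(seqconv(x, lod, w, context_length=4, context_start=0) + b, 0)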
