
Commit ac3370a

Add unit testing for gemv and fix the gradient check for bias.
1 parent 2e02987 commit ac3370a

5 files changed: +165 -42 lines

paddle/framework/lod_tensor_test.cu

Lines changed: 4 additions & 4 deletions
@@ -36,15 +36,15 @@ TEST(LoDTensor, LoDInGPU) {
   lod_tensor.mutable_data<float>(place);

   lod_tensor.set_lod(src_lod);
-  CHECK_EQ(lod_tensor.lod_element(0, 2).first, 4UL);
-  CHECK_EQ(lod_tensor.lod_element(0, 4).first, 8UL);
+  EXPECT_EQ(lod_tensor.lod_element(0, 2).first, 4UL);
+  EXPECT_EQ(lod_tensor.lod_element(0, 4).first, 8UL);

   auto lod = lod_tensor.lod();

   test<<<1, 8>>>(lod[0].data(), lod[0].size());
   cudaDeviceSynchronize();

   for (size_t i = 0; i < src_lod[0].size(); ++i) {
-    CHECK_EQ(lod[0].data()[i], src_lod[0].data()[i] * 2);
+    EXPECT_EQ(lod[0].data()[i], src_lod[0].data()[i] * 2);
   }
-}
+}
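Note: the switch from CHECK_EQ to EXPECT_EQ matters for test behaviour. glog's CHECK_EQ aborts the whole test binary on a mismatch, while gtest's EXPECT_EQ records a non-fatal failure and lets the remaining checks run. A minimal gtest sketch of that non-fatal behaviour (illustration only, not from this commit; the test name is made up):

// Minimal gtest sketch: EXPECT_EQ records a failure and keeps going,
// whereas a fatal check (ASSERT_EQ, or glog's CHECK_EQ) stops at the first mismatch.
#include <gtest/gtest.h>

TEST(ExpectVsCheck, NonFatalComparisons) {
  int values[4] = {2, 4, 6, 8};
  for (int i = 0; i < 4; ++i) {
    EXPECT_EQ(values[i] % 2, 0);  // every element is still checked even if one fails
  }
}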

paddle/operators/lstm_op.h

Lines changed: 5 additions & 2 deletions
@@ -162,9 +162,9 @@ class LSTMGradKernel : public framework::OpKernel<T> {
     auto* bias_g = ctx.Output<Tensor>(framework::GradVarName("Bias"));

     auto& device_ctx = ctx.device_context();
+    math::SetConstant<Place, T> zero;
     if (weight_g) {
       weight_g->mutable_data<T>(ctx.GetPlace());
-      math::SetConstant<Place, T> zero;
       zero(device_ctx, weight_g, static_cast<T>(0.0));
     }

@@ -188,6 +188,7 @@ class LSTMGradKernel : public framework::OpKernel<T> {
     math::LstmMetaGrad<T> lstm_grad;
     if (bias && bias_g) {
       T* bias_g_data = const_cast<T*>(bias_g->mutable_data<T>(ctx.GetPlace()));
+      zero(device_ctx, bias_g, static_cast<T>(0.0));
       lstm_grad.checkIgGrad = bias_g_data + 4 * frame_size;
       lstm_grad.checkFgGrad = lstm_grad.checkIgGrad + frame_size;
       lstm_grad.checkOgGrad = lstm_grad.checkFgGrad + frame_size;
@@ -219,6 +220,8 @@ class LSTMGradKernel : public framework::OpKernel<T> {
     batch_cell_g.mutable_data<T>(out_dims, ctx.GetPlace());
     batch_cell_g.set_lod(batch_gate->lod());
     to_batch(device_ctx, *cell_g, batch_cell_g, false);
+    // TODO(qingqing) support the case output cell has gradient.
+    zero(device_ctx, &batch_cell_g, static_cast<T>(0.0));

     LoDTensor batch_gate_g;
     batch_gate_g.mutable_data<T>(batch_gate->dims(), ctx.GetPlace());
@@ -304,7 +307,7 @@ class LSTMGradKernel : public framework::OpKernel<T> {
       int n = static_cast<int>(batch_gate_g.dims()[1]);

       Tensor ones;
-      ones.mutable_data<T>({1, m}, ctx.GetPlace());
+      ones.mutable_data<T>({m}, ctx.GetPlace());
       math::SetConstant<Place, T> set;
       set(device_ctx, &ones, static_cast<T>(1.0));
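Note: three things change in the gradient kernel above. The math::SetConstant functor `zero` is hoisted out of the weight branch so it can also clear `bias_g` before the peephole gradient pointers (checkIgGrad/checkFgGrad/checkOgGrad) are aimed into it, `batch_cell_g` is cleared because a gradient arriving through the cell output is not handled yet (see the TODO), and the `ones` tensor becomes a plain {m} vector, matching gemv's vector operand. Assuming that ones vector is fed to a transposed gemv with alpha = 1 and beta = 0 to reduce the m x n gate-gradient matrix into the bias gradient, the computation amounts to column sums. A plain-loop sketch of that reduction (the helper name is made up; this is not the Paddle math::gemv API):

// Hedged sketch: a transposed gemv with alpha = 1, beta = 0 computes y = A^T * x.
// Feeding a ones vector of length m as x yields the column sums of the m x n
// matrix A, i.e. an accumulation over the batch dimension.
#include <cstdio>
#include <vector>

std::vector<float> ColumnSumsViaOnes(const std::vector<float>& a, int m, int n) {
  std::vector<float> ones(m, 1.0f);  // plays the role of the {m}-shaped ones tensor
  std::vector<float> y(n, 0.0f);     // beta = 0: start from zero
  for (int j = 0; j < n; ++j) {
    for (int i = 0; i < m; ++i) {
      y[j] += a[i * n + j] * ones[i];  // sum of column j over all m rows
    }
  }
  return y;
}

int main() {
  // 2 x 3 row-major matrix {{0, 1, 2}, {3, 4, 5}} -> column sums {3, 5, 7}
  std::vector<float> a = {0, 1, 2, 3, 4, 5};
  for (float v : ColumnSumsViaOnes(a, 2, 3)) std::printf("%g ", v);
  std::printf("\n");
  return 0;
}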

paddle/operators/math/math_function_test.cc

Lines changed: 50 additions & 0 deletions
@@ -89,3 +89,53 @@ TEST(math_function, zero) {
   EXPECT_EQ(t[2], 1);
   EXPECT_EQ(t[3], 1);
 }
+
+template <typename T>
+void GemvTest(int m, int n, bool trans) {
+  paddle::framework::Tensor mat_a;
+  paddle::framework::Tensor vec_b;
+  paddle::framework::Tensor vec_c;
+  auto* cpu_place = new paddle::platform::CPUPlace();
+  int b_num = trans ? m : n;
+  int c_num = trans ? n : m;
+
+  T* data_a = mat_a.mutable_data<T>({m, n}, *cpu_place);
+  T* data_b = vec_b.mutable_data<T>({b_num}, *cpu_place);
+  T* data_c = vec_c.mutable_data<T>({c_num}, *cpu_place);
+  for (int i = 0; i < mat_a.numel(); ++i) {
+    data_a[i] = static_cast<T>(i);
+  }
+  for (int i = 0; i < vec_b.numel(); ++i) {
+    data_b[i] = static_cast<T>(i);
+  }
+
+  paddle::platform::CPUDeviceContext context(*cpu_place);
+  paddle::operators::math::gemv<paddle::platform::CPUPlace, T>(
+      context, trans, static_cast<int>(m), static_cast<int>(n), 1., data_a,
+      data_b, 0., data_c);
+
+  if (!trans) {
+    for (int i = 0; i < m; ++i) {
+      T sum = 0.0;
+      for (int j = 0; j < n; ++j) {
+        sum += data_a[i * n + j] * data_b[j];
+      }
+      ASSERT_FLOAT_EQ(data_c[i], sum);
+    }
+  } else {
+    for (int i = 0; i < n; ++i) {
+      T sum = 0.0;
+      for (int j = 0; j < m; ++j) {
+        sum += data_a[j * n + i] * data_b[j];
+      }
+      ASSERT_FLOAT_EQ(data_c[i], sum);
+    }
+  }
+}
+
+TEST(math_function, gemv) {
+  GemvTest<float>(3, 13, false);
+  GemvTest<double>(4, 5, false);
+  GemvTest<float>(12, 7, true);
+  GemvTest<double>(7, 9, true);
+}
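The new CPU test fills A (m x n) and b with 0, 1, 2, ... and checks gemv against a naive reference: with alpha = 1 and beta = 0 the call is c = A * b (length m) in the untransposed case and c = A^T * b (length n) when trans is set. A standalone worked example of the untransposed case (illustration only, not part of the commit):

// Worked example with m = 2, n = 3 and the same 0, 1, 2, ... fill pattern:
//   A = {{0, 1, 2}, {3, 4, 5}}, b = {0, 1, 2}  ->  c = A * b = {5, 14}
#include <cassert>

int main() {
  const int m = 2, n = 3;
  float a[m * n] = {0, 1, 2, 3, 4, 5};  // row-major A
  float b[n] = {0, 1, 2};
  float c[m] = {0, 0};
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      c[i] += a[i * n + j] * b[j];
    }
  }
  assert(c[0] == 5.0f);   // 0*0 + 1*1 + 2*2
  assert(c[1] == 14.0f);  // 3*0 + 4*1 + 5*2
  return 0;
}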

paddle/operators/math/math_function_test.cu

Lines changed: 62 additions & 0 deletions
@@ -177,3 +177,65 @@ TEST(math_function, gemm_trans_cublas) {
   EXPECT_EQ(input3_ptr[7], 99);
   delete gpu_place;
 }
+
+template <typename T>
+void GemvTest(int m, int n, bool trans) {
+  paddle::framework::Tensor mat_a;
+  paddle::framework::Tensor vec_b;
+  paddle::framework::Tensor vec_c;
+  auto* cpu_place = new paddle::platform::CPUPlace();
+
+  T* data_a = mat_a.mutable_data<T>({m, n}, *cpu_place);
+  T* data_b = vec_b.mutable_data<T>({trans ? m : n}, *cpu_place);
+  T* data_c = vec_c.mutable_data<T>({trans ? n : m}, *cpu_place);
+
+  auto* gpu_place = new paddle::platform::GPUPlace(0);
+  paddle::framework::Tensor g_mat_a;
+  paddle::framework::Tensor g_vec_b;
+  paddle::framework::Tensor g_vec_c;
+  T* g_data_a = g_mat_a.mutable_data<T>(mat_a.dims(), *gpu_place);
+  T* g_data_b = g_vec_b.mutable_data<T>(vec_b.dims(), *gpu_place);
+  T* g_data_c = g_vec_c.mutable_data<T>(vec_c.dims(), *gpu_place);
+
+  for (int i = 0; i < mat_a.numel(); ++i) {
+    data_a[i] = static_cast<T>(i);
+  }
+  for (int i = 0; i < vec_b.numel(); ++i) {
+    data_b[i] = static_cast<T>(i);
+  }
+
+  paddle::platform::CUDADeviceContext context(*gpu_place);
+  g_mat_a.CopyFrom(mat_a, *gpu_place, context);
+  g_vec_b.CopyFrom(vec_b, *gpu_place, context);
+
+  paddle::operators::math::gemv<paddle::platform::GPUPlace, T>(
+      context, trans, static_cast<int>(m), static_cast<int>(n), 1., g_data_a,
+      g_data_b, 0., g_data_c);
+
+  vec_c.CopyFrom(g_vec_c, paddle::platform::CPUPlace(), context);
+
+  if (!trans) {
+    for (int i = 0; i < m; ++i) {
+      T sum = 0.0;
+      for (int j = 0; j < n; ++j) {
+        sum += data_a[i * n + j] * data_b[j];
+      }
+      ASSERT_FLOAT_EQ(data_c[i], sum);
+    }
+  } else {
+    for (int i = 0; i < n; ++i) {
+      T sum = 0.0;
+      for (int j = 0; j < m; ++j) {
+        sum += data_a[j * n + i] * data_b[j];
+      }
+      ASSERT_FLOAT_EQ(data_c[i], sum);
+    }
+  }
+}
+
+TEST(math_function, gemv) {
+  GemvTest<float>(3, 13, false);
+  GemvTest<double>(3, 13, false);
+  GemvTest<float>(3, 13, true);
+  GemvTest<double>(3, 13, true);
+}
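The GPU variant drives the same check through device memory: the inputs are prepared on the host, copied to the device with CopyFrom under a CUDADeviceContext, gemv runs on the GPUPlace, and the result vector is copied back to the CPU before being compared against the same naive host-side loops. All four cases here reuse a single 3 x 13 shape, varying only the element type and the trans flag.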

python/paddle/v2/framework/tests/test_lstm_op.py

Lines changed: 44 additions & 36 deletions
@@ -114,26 +114,20 @@ def _reverse(x, lod):


 class TestLstmOp(OpTest):
-    def set_data(self):
-        # self.lod = [[0, 2, 6, 9]]
-        # self.D = 64
-        # self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5]
-
-        self.lod = [[0, 1]]
-        self.D = 4
-        self.sort_idx = [0]
-
-        # self.act_gate = 'identity'
-        # self.act_cell = 'identity'
-        # self.act_cand = 'identity'
+    def set_argument(self):
+        self.lod = [[0, 2, 6, 9]]
+        self.D = 16
+        self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5]
+
         self.act_gate = 'sigmoid'
         self.act_cell = 'tanh'
         self.act_cand = 'tanh'

+        self.has_initial_state = True
         self.is_reverse = False

     def setUp(self):
-        self.set_data()
+        self.set_argument()
         self.op_type = 'lstm'

         T = self.lod[0][-1]
@@ -155,17 +149,14 @@ def setUp(self):
         for i, j in enumerate(self.sort_idx):
             g_sort[i, :] = g[j, :]

-        self.inputs = {
-            'Input': (x, self.lod),
-            'H0': h0,
-            'C0': c0,
-            'Weight': w,
-            'Bias': b
-        }
+        self.inputs = {'Input': (x, self.lod), 'Weight': w, 'Bias': b}
+        self.inputs['H0'] = h0
+        self.inputs['C0'] = c0
+
         self.outputs = {
             'Hidden': (h, self.lod),
             'Cell': (c, self.lod),
-            #'BatchGate': g_sort,
+            'BatchGate': g_sort,
         }
         self.attrs = {
             'usePeepholes': True,
@@ -175,26 +166,43 @@ def setUp(self):
             'candidateActivation': self.act_cand
         }

-    def not_test_check_output(self):
+    def test_check_output(self):
         self.check_output()

+    #TODO(qingqing) add more unit testing case
     def test_check_grad(self):
+        # TODO(qingqing) remove folowing two lines after the check_grad is refined.
         self.outputs['BatchGate'] = None
         self.outputs['BatchCellPreAct'] = None
-        self.check_grad(['Input', 'Weight'], ['Hidden', 'Cell'])
-        #['Input', 'Weight', 'Bias'], ['Hidden', 'Cell'])
-
-#class TestLstmOpRerverse(TestLstmOp):
-#    def set_data(self):
-#        self.lod = [[0, 2, 6, 9]]
-#        self.D = 64
-#        self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5]
-#
-#        self.act_gate = 'sigmoid'
-#        self.act_cell = 'tanh'
-#        self.act_cand = 'tanh'
-#
-#        self.is_reverse = True
+        self.check_grad(['Input', 'Weight', 'Bias'], ['Hidden'])
+
+
+class TestLstmOpHasNoInitial(TestLstmOp):
+    def set_argument(self):
+        self.lod = [[0, 2, 6, 9]]
+        self.D = 64
+        self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5]
+
+        self.act_gate = 'sigmoid'
+        self.act_cell = 'tanh'
+        self.act_cand = 'tanh'
+
+        self.has_initial_state = False
+        self.is_reverse = True
+
+
+class TestLstmOpRerverse(TestLstmOp):
+    def set_argument(self):
+        self.lod = [[0, 2, 6, 9]]
+        self.D = 64
+        self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5]
+
+        self.act_gate = 'sigmoid'
+        self.act_cell = 'tanh'
+        self.act_cand = 'tanh'
+
+        self.has_initial_state = True
+        self.is_reverse = True


 if __name__ == '__main__':
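On the Python side the test tightens accordingly: check_output now runs (it was previously disabled as not_test_check_output) and compares BatchGate again, the gradient check includes Bias among the differentiated inputs (presumably what required zeroing bias_g in the kernel), and two subclasses exercise the reversed and no-initial-state configurations through the new set_argument hook. BatchGate and BatchCellPreAct are still dropped from the outputs before check_grad, as the TODO in the diff notes.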
