Skip to content

Commit bcc0dad

Browse files
committed
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into lstm_bp
2 parents ac3370a + 4273b35 commit bcc0dad

File tree

9 files changed

+73
-23
lines changed

9 files changed

+73
-23
lines changed

paddle/capi/gradient_machine.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,18 @@ paddle_error paddle_gradient_machine_create_for_inference_with_parameters(
6464
modelConfigProtobuf.resize(modelConfigSize);
6565
is.read(&modelConfigProtobuf[0], modelConfigSize);
6666
paddle::TrainerConfig config;
67+
paddle::ModelConfig modelConfig;
6768
if (!config.ParseFromString(modelConfigProtobuf) || !config.IsInitialized()) {
68-
return kPD_PROTOBUF_ERROR;
69+
if (!modelConfig.ParseFromString(modelConfigProtobuf) ||
70+
!modelConfig.IsInitialized()) {
71+
return kPD_PROTOBUF_ERROR;
72+
}
73+
} else {
74+
modelConfig = config.model_config();
6975
}
7076
auto ptr = new paddle::capi::CGradientMachine();
7177
ptr->machine.reset(paddle::GradientMachine::create(
72-
config.model_config(), CREATE_MODE_TESTING, {paddle::PARAMETER_VALUE}));
78+
modelConfig, CREATE_MODE_TESTING, {paddle::PARAMETER_VALUE}));
7379
std::vector<paddle::ParameterPtr>& parameters = ptr->machine->getParameters();
7480
for (auto& para : parameters) {
7581
para->load(is);

paddle/operators/cross_entropy_op.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,8 @@ or not. But the output only shares the LoD with input `X`.
162162
namespace ops = paddle::operators;
163163
REGISTER_OP(cross_entropy, ops::CrossEntropyOp, ops::CrossEntropyOpMaker,
164164
cross_entropy_grad, ops::CrossEntropyGradientOp);
165-
REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<float>);
165+
REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<float>,
166+
ops::CrossEntropyOpKernel<double>);
166167
REGISTER_OP_CPU_KERNEL(cross_entropy_grad,
167-
ops::CrossEntropyGradientOpKernel<float>);
168+
ops::CrossEntropyGradientOpKernel<float>,
169+
ops::CrossEntropyGradientOpKernel<double>);

paddle/operators/cross_entropy_op.cu

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,8 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel<T> {
108108
} // namespace paddle
109109

110110
namespace ops = paddle::operators;
111-
REGISTER_OP_GPU_KERNEL(cross_entropy, ops::CrossEntropyOpCUDAKernel<float>);
111+
REGISTER_OP_GPU_KERNEL(cross_entropy, ops::CrossEntropyOpCUDAKernel<float>,
112+
ops::CrossEntropyOpCUDAKernel<double>);
112113
REGISTER_OP_GPU_KERNEL(cross_entropy_grad,
113-
ops::CrossEntropyGradientOpCUDAKernel<float>);
114+
ops::CrossEntropyGradientOpCUDAKernel<float>,
115+
ops::CrossEntropyGradientOpCUDAKernel<double>);

paddle/operators/math/cross_entropy.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class CrossEntropyFunctor<platform::CPUPlace, T> {
5454
};
5555

5656
template class CrossEntropyFunctor<platform::CPUPlace, float>;
57+
template class CrossEntropyFunctor<platform::CPUPlace, double>;
5758
} // namespace math
5859
} // namespace operators
5960
} // namespace paddle

paddle/operators/math/cross_entropy.cu

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,36 @@ __device__ __forceinline__ T sum_single_warp(T val) {
3939
return val;
4040
}
4141

42+
// CUDA does not support dynamic shared-memory arrays in templates
43+
// https://stackoverflow.com/questions/20497209
44+
template <typename T>
45+
struct SharedMemory {
46+
// Ensure that we won't compile any un-specialized types
47+
__device__ T* GetPointer() { return NULL; }
48+
};
49+
50+
template <>
51+
struct SharedMemory<float> {
52+
__device__ float* GetPointer() {
53+
extern __shared__ float s_float[];
54+
return s_float;
55+
}
56+
};
57+
58+
template <>
59+
struct SharedMemory<double> {
60+
__device__ double* GetPointer() {
61+
extern __shared__ double s_double[];
62+
return s_double;
63+
}
64+
};
65+
4266
template <typename T>
4367
__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
4468
const int class_num) {
4569
int tid = threadIdx.x;
46-
extern __shared__ T d_sum[];
70+
SharedMemory<T> d_sum_shared;
71+
T* d_sum = d_sum_shared.GetPointer();
4772
d_sum[tid] = 0;
4873

4974
int cur_idx = tid;
@@ -102,6 +127,7 @@ class CrossEntropyFunctor<platform::GPUPlace, T> {
102127
};
103128

104129
template class CrossEntropyFunctor<platform::GPUPlace, float>;
130+
template class CrossEntropyFunctor<platform::GPUPlace, double>;
105131
} // namespace math
106132
} // namespace operators
107133
} // namespace paddle

paddle/trainer/MergeModel.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ limitations under the License. */
2020
#include "paddle/utils/PythonUtil.h"
2121

2222
DEFINE_string(model_dir, "", "Directory for separated model files");
23+
DEFINE_string(config_file, "", "Config file for the model");
2324
DEFINE_string(model_file, "", "File for merged model file");
2425

2526
using namespace paddle; // NOLINT
@@ -28,7 +29,8 @@ using namespace std; // NOLINT
2829
int main(int argc, char** argv) {
2930
initMain(argc, argv);
3031
initPython(argc, argv);
31-
string confFile = TrainerConfigHelper::getConfigNameFromPath(FLAGS_model_dir);
32+
33+
string confFile = FLAGS_config_file;
3234
#ifndef PADDLE_WITH_CUDA
3335
FLAGS_use_gpu = false;
3436
#endif

proto/TrainerConfig.proto

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import "ModelConfig.proto";
1919
package paddle;
2020

2121
message OptimizationConfig {
22-
required int32 batch_size = 3;
22+
optional int32 batch_size = 3 [ default = 1 ];
2323
required string algorithm = 4 [ default = "async_sgd" ];
2424
optional int32 num_batches_per_send_parameter = 5 [ default = 1 ];
2525
optional int32 num_batches_per_get_parameter = 6 [ default = 1 ];

python/paddle/v2/framework/tests/op_test.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,15 @@
88
from paddle.v2.framework.framework import Program, OpProtoHolder
99

1010

11+
def randomize_probability(batch_size, class_num, dtype='float32'):
12+
prob = np.random.uniform(
13+
0.1, 1.0, size=(batch_size, class_num)).astype(dtype)
14+
prob_sum = prob.sum(axis=1)
15+
for i in xrange(len(prob)):
16+
prob[i] /= prob_sum[i]
17+
return prob
18+
19+
1120
def grad_var_name(var_name):
1221
return var_name + "@GRAD"
1322

@@ -233,7 +242,7 @@ def create_var(block, name, np_list, var_proto):
233242
if (var_name not in np_list) and var_proto.dispensable:
234243
continue
235244
assert (var_name in np_list) or (var_proto.dispensable), \
236-
"Missing {} as input".format(var_name)
245+
"Missing {} as input".format(var_name)
237246
if var_proto.duplicable:
238247
assert isinstance(np_list[var_name], list), \
239248
"Duplicable {} should be set as list".format(var_name)
@@ -379,16 +388,17 @@ def __assert_is_close(self, numeric_grads, analytic_grads, names,
379388
def err_msg():
380389
offset = np.argmax(diff_mat > max_relative_error)
381390
return ("%s Variable %s max gradient diff %f over limit %f, "
382-
"the first error element is %d") % (
391+
"the first error element is %d, %f, %f") % (
383392
msg_prefix, name, max_diff, max_relative_error,
384-
offset)
393+
offset, a.flatten()[offset], b.flatten()[offset])
385394

386395
self.assertLessEqual(max_diff, max_relative_error, err_msg())
387396

388397
def check_grad(self,
389398
inputs_to_check,
390399
output_names,
391400
no_grad_set=None,
401+
numeric_grad_delta=0.005,
392402
in_place=False,
393403
max_relative_error=0.005,
394404
user_defined_grads=None):
@@ -411,6 +421,7 @@ def check_grad(self,
411421
self.inputs,
412422
input_to_check,
413423
output_names,
424+
delta=numeric_grad_delta,
414425
in_place=in_place) for input_to_check in inputs_to_check
415426
]
416427
grad_names = [

python/paddle/v2/framework/tests/test_cross_entropy_op.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import unittest
22
import numpy as np
3-
from op_test import OpTest
3+
from op_test import OpTest, randomize_probability
44

55

66
class TestCrossEntropyOp1(OpTest):
@@ -12,12 +12,12 @@ def setUp(self):
1212
batch_size = 30
1313
class_num = 10
1414

15-
X = np.random.uniform(0.1, 1.0,
16-
[batch_size, class_num]).astype("float32")
15+
X = randomize_probability(batch_size, class_num, dtype='float64')
16+
1717
label = np.random.randint(0, class_num, (batch_size, 1), dtype="int32")
1818
cross_entropy = np.asmatrix(
1919
[[-np.log(X[i][label[i][0]])] for i in range(X.shape[0])],
20-
dtype="float32")
20+
dtype="float64")
2121

2222
self.inputs = {"X": X, "Label": label}
2323
self.outputs = {"Y": cross_entropy}
@@ -27,7 +27,7 @@ def test_check_output(self):
2727
self.check_output()
2828

2929
def test_check_grad(self):
30-
self.check_grad(["X"], "Y")
30+
self.check_grad(["X"], "Y", numeric_grad_delta=0.001)
3131

3232

3333
class TestCrossEntropyOp2(OpTest):
@@ -39,8 +39,7 @@ def setUp(self):
3939
batch_size = 5
4040
class_num = 37
4141

42-
X = np.random.uniform(0.1, 1.0,
43-
[batch_size, class_num]).astype("float32")
42+
X = randomize_probability(batch_size, class_num)
4443
label = np.random.uniform(0.1, 1.0,
4544
[batch_size, class_num]).astype("float32")
4645
label /= label.sum(axis=1, keepdims=True)
@@ -55,7 +54,8 @@ def test_check_output(self):
5554
self.check_output()
5655

5756
def test_check_grad(self):
58-
self.check_grad(["X"], "Y", max_relative_error=0.05)
57+
self.check_grad(
58+
["X"], "Y", max_relative_error=0.05, numeric_grad_delta=0.001)
5959

6060

6161
class TestCrossEntropyOp3(OpTest):
@@ -67,8 +67,7 @@ def setUp(self):
6767
batch_size = 5
6868
class_num = 17
6969

70-
X = np.random.uniform(0.1, 1.0,
71-
[batch_size, class_num]).astype("float32")
70+
X = randomize_probability(batch_size, class_num)
7271
label_index = np.random.randint(
7372
0, class_num, (batch_size), dtype="int32")
7473
label = np.zeros(X.shape)
@@ -88,7 +87,8 @@ def test_check_output(self):
8887
self.check_output()
8988

9089
def test_check_grad(self):
91-
self.check_grad(["X"], "Y", max_relative_error=0.05)
90+
self.check_grad(
91+
["X"], "Y", max_relative_error=0.05, numeric_grad_delta=0.001)
9292

9393

9494
if __name__ == "__main__":

0 commit comments

Comments
 (0)