
Commit cb14b0d

merge release/1.0.0

2 parents: ef4ceec + b97257b


6 files changed: +193 −43 lines


paddle/fluid/operators/uniform_random_op.cc

Lines changed: 16 additions & 16 deletions
@@ -23,14 +23,14 @@ namespace operators {
 template <typename T>
 class CPUUniformRandomKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    framework::Tensor* tensor = nullptr;
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    framework::Tensor *tensor = nullptr;
     auto out_var = ctx.OutputVar("Out");
     if (out_var->IsType<framework::LoDTensor>()) {
       tensor = out_var->GetMutable<framework::LoDTensor>();
     } else if (out_var->IsType<framework::SelectedRows>()) {
       auto shape = ctx.Attr<std::vector<int>>("shape");
-      auto* selected_rows = out_var->GetMutable<framework::SelectedRows>();
+      auto *selected_rows = out_var->GetMutable<framework::SelectedRows>();
       tensor = selected_rows->mutable_value();
       tensor->Resize(framework::make_ddim(shape));
       selected_rows->mutable_rows()->reserve(shape[0]);
@@ -39,7 +39,7 @@ class CPUUniformRandomKernel : public framework::OpKernel<T> {
           "uniform_random_op's output only"
           "supports SelectedRows and LoDTensor");
     }
-    T* data = tensor->mutable_data<T>(ctx.GetPlace());
+    T *data = tensor->mutable_data<T>(ctx.GetPlace());
     unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
     std::minstd_rand engine;
     if (seed == 0) {
@@ -60,14 +60,14 @@ class UniformRandomOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext* ctx) const override {
+  void InferShape(framework::InferShapeContext *ctx) const override {
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of UniformRandomOp should not be null.");
 
     PADDLE_ENFORCE(
         ctx->Attrs().Get<float>("min") < ctx->Attrs().Get<float>("max"),
         "uniform_random's min must less then max");
-    auto& shape = ctx->Attrs().Get<std::vector<int>>("shape");
+    auto &shape = ctx->Attrs().Get<std::vector<int>>("shape");
     std::vector<int64_t> temp;
     temp.reserve(shape.size());
     for (auto dim : shape) {
@@ -78,7 +78,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {
 
  protected:
   framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override {
+      const framework::ExecutionContext &ctx) const override {
     return framework::OpKernelType(
         static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
         ctx.GetPlace());
@@ -112,17 +112,17 @@ uniform distribution. The random result is in set [min, max].
 
 class UniformRandomOpVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override {
+  void operator()(const framework::OpDesc &op_desc,
+                  framework::BlockDesc *block) const override {
     auto out_var_name = op_desc.Output("Out").front();
-    if (block->FindRecursiveOrCreateVar(out_var_name).GetType() ==
-        framework::proto::VarType::SELECTED_ROWS) {
-      block->FindRecursiveOrCreateVar(out_var_name)
-          .SetType(framework::proto::VarType::SELECTED_ROWS);
-    } else {
-      block->FindRecursiveOrCreateVar(out_var_name)
-          .SetType(framework::proto::VarType::LOD_TENSOR);
+    auto var_data_type = static_cast<framework::proto::VarType::Type>(
+        boost::get<int>(op_desc.GetAttr("dtype")));
+
+    auto out_var = block->FindRecursiveOrCreateVar(out_var_name);
+    if (out_var.GetType() != framework::proto::VarType::SELECTED_ROWS) {
+      out_var.SetType(framework::proto::VarType::LOD_TENSOR);
     }
+    out_var.SetDataType(var_data_type);
   }
 };
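
The net effect of the VarTypeInference change above is that uniform_random's output variable now also carries the op's "dtype" attribute, not just its variable type (LoDTensor vs. SelectedRows). A minimal Python-side sketch of the user-facing behaviour, assuming the 1.0-era fluid API wrapped by ops.py below:

import paddle.fluid as fluid

# dtype is forwarded to the op as an attribute; with this commit the output
# variable's data type is inferred from it as well.
x = fluid.layers.uniform_random(shape=[2, 3], dtype='float32',
                                min=-1.0, max=1.0, seed=10)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
out, = exe.run(fluid.default_main_program(), fetch_list=[x])
print(out.shape, out.dtype)  # (2, 3) float32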

paddle/fluid/pybind/pybind.cc

Lines changed: 140 additions & 18 deletions
@@ -156,7 +156,50 @@ PYBIND11_PLUGIN(core) {
       .def("_get_double_element", TensorGetElement<double>)
       .def("_dtype", [](Tensor &self) { return ToDataType(self.type()); });
 
-  py::class_<LoDTensor, Tensor>(m, "LoDTensor")
+  py::class_<LoDTensor, Tensor>(m, "LoDTensor", R"DOC(
+    LoDTensor is a Tensor with optional LoD information.
+
+    np.array(lod_tensor) can convert LoDTensor to a numpy array.
+    lod_tensor.lod() can retrieve the LoD information.
+
+    LoD is short for Level of Details and is usually used for varied sequence
+    lengths. You can skip the following explanation if you don't need optional LoD.
+
+    For example:
+      A LoDTensor X can look like the example below. It contains 2 sequences.
+      The first has length 2 and the second has length 3, as described by x.lod.
+
+      The first tensor dimension 5=2+3 is calculated from LoD if it's available.
+      It means the total number of sequence elements. In X, each element has 2
+      columns, hence [5, 2].
+
+      x.lod  = [[2, 3]]
+      x.data = [[1, 2], [3, 4],  // seq 1
+                [5, 6], [7, 8], [9, 10]]  // seq 2
+      x.shape = [5, 2]
+
+      LoD can have multiple levels (for example, a paragraph can have multiple
+      sentences and a sentence can have multiple words). In the following
+      LoDTensor Y, the lod_level is 2. It means there are 2 sequences: the
+      first sequence's length is 2 (it has 2 sub-sequences) and the second one's
+      length is 1. The first sequence's 2 sub-sequences have lengths 2 and 2,
+      respectively, and the second sequence's single sub-sequence has length 3.
+
+      y.lod = [[2 1], [2 2 3]]
+      y.shape = [2+2+3, ...]
+
+    Note:
+      In the above description, LoD is length-based. In Paddle's internal
+      implementation, lod is offset-based. Hence, internally,
+      y.lod is represented as [[0, 2, 3], [0, 2, 4, 7]] (the length-based
+      equivalent would be [[2-0, 3-2], [2-0, 4-2, 7-4]]).
+
+      Sometimes LoD is called recursive_sequence_length to be more
+      self-explanatory. In that case, it must be length-based. For historical
+      reasons, when LoD is called lod in the public API, it might be offset-based.
+      Users should be careful about this.
+
+    )DOC")
       .def_buffer(
           [](Tensor &self) -> py::buffer_info { return CastToPyBuffer(self); })
       .def("__init__",
@@ -596,34 +639,78 @@ All parameter, weight, gradient are variables in Paddle.
 
   // -- python binds for parallel executor.
   py::class_<ParallelExecutor> pe(m, "ParallelExecutor");
-  py::class_<ExecutionStrategy> exec_strategy(pe, "ExecutionStrategy");
+  py::class_<ExecutionStrategy> exec_strategy(pe, "ExecutionStrategy", R"DOC(
+    ExecutionStrategy allows the user to more precisely control how to run
+    the program in ParallelExecutor by setting its properties.
+
+    Examples:
+        .. code-block:: python
+
+          exec_strategy = fluid.ExecutionStrategy()
+          exec_strategy.num_threads = 4
+
+          train_exe = fluid.ParallelExecutor(use_cuda=True,
+                                             loss_name=loss.name,
+                                             exec_strategy=exec_strategy)
+
+          train_loss, = train_exe.run([loss.name], feed=feed_dict)
+
+        )DOC");
+
   exec_strategy.def(py::init())
       .def_property(
           "num_threads",
           [](const ExecutionStrategy &self) { return self.num_threads_; },
           [](ExecutionStrategy &self, size_t num_threads) {
            self.num_threads_ = num_threads;
-          })
+          },
+          R"DOC(The type is INT. num_threads represents the size of the thread pool
+            used to run the operators of the current program in ParallelExecutor.
+            If :math:`num\_threads=1`, all the operators will execute one by one,
+            but the order may differ between iterations.
+            If it is not set, it will be set in ParallelExecutor according to the
+            device type and device count: for GPU, :math:`num\_threads=device\_count*4`; for CPU,
+            :math:`num\_threads=CPU\_NUM*4`. The explanation of :math:`CPU\_NUM` is in ParallelExecutor;
+            if it is not set, ParallelExecutor will get the CPU count by calling
+            `multiprocessing.cpu_count()`. Default 0.)DOC")
       .def_property(
           "use_cuda",
           [](const ExecutionStrategy &self) { return self.use_cuda_; },
           [](ExecutionStrategy &self, bool use_cuda) {
            self.use_cuda_ = use_cuda;
-          })
+          })  // FIXME(chengduo): No doc is added for 'use_cuda' yet; use_cuda may
+              // confuse users, because ParallelExecutor also has a parameter named
+              // 'use_cuda', and in the current implementation ParallelExecutor's
+              // 'use_cuda' overwrites ExecutionStrategy's 'use_cuda'.
      .def_property(
          "allow_op_delay",
          [](const ExecutionStrategy &self) { return self.allow_op_delay_; },
          [](ExecutionStrategy &self, bool allow_op_delay) {
            self.allow_op_delay_ = allow_op_delay;
-          })
+          },
+          R"DOC(The type is BOOL. allow_op_delay represents whether to delay
+            running the communication operators; it may make the execution faster.
+            Note that in some models allow_op_delay may cause the program to hang. Default False.)DOC")
      .def_property(
          "num_iteration_per_drop_scope",
          [](const ExecutionStrategy &self) {
            return self.num_iteration_per_drop_scope_;
          },
          [](ExecutionStrategy &self, size_t num_iteration_per_drop_scope) {
            self.num_iteration_per_drop_scope_ = num_iteration_per_drop_scope;
-          });
+          },
+          R"DOC(The type is INT. num_iteration_per_drop_scope indicates how
+            many iterations to wait before cleaning up the temp variables
+            generated during execution. It may make the execution faster,
+            because the temp variables' shapes may be the same between two iterations. Default 100.
+
+            NOTES:
+                1. If you fetch data when calling 'run', ParallelExecutor
+                   will clean up the temp variables at the end of the current iteration.
+                2. In some NLP models it may cause GPU memory to be insufficient;
+                   in that case, you should reduce `num_iteration_per_drop_scope`.
+          )DOC");
+
   exec_strategy.def_property(
       "use_experimental_executor",
      [](const ExecutionStrategy &self) {
@@ -634,7 +721,22 @@ All parameter, weight, gradient are variables in Paddle.
                        : ExecutionStrategy::kDefault;
      });
 
-  py::class_<BuildStrategy> build_strategy(pe, "BuildStrategy");
+  py::class_<BuildStrategy> build_strategy(pe, "BuildStrategy", R"DOC(
+    BuildStrategy allows the user to more precisely control how to
+    build the SSA Graph in ParallelExecutor by setting its properties.
+
+    Examples:
+        .. code-block:: python
+
+          build_strategy = fluid.BuildStrategy()
+          build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
+
+          train_exe = fluid.ParallelExecutor(use_cuda=True,
+                                             loss_name=loss.name,
+                                             build_strategy=build_strategy)
+
+          train_loss, = train_exe.run([loss.name], feed=feed_dict)
+    )DOC");
 
   py::enum_<BuildStrategy::ReduceStrategy>(build_strategy, "ReduceStrategy")
       .value("Reduce", BuildStrategy::ReduceStrategy::kReduce)
@@ -652,31 +754,51 @@ All parameter, weight, gradient are variables in Paddle.
          [](const BuildStrategy &self) { return self.reduce_; },
          [](BuildStrategy &self, BuildStrategy::ReduceStrategy strategy) {
            self.reduce_ = strategy;
-          })
+          },
+          R"DOC(The type is STR. There are two reduce strategies in ParallelExecutor:
+            'AllReduce' and 'Reduce'. If you want each parameter's
+            optimization to be done on all devices independently, choose 'AllReduce';
+            if you choose 'Reduce', the parameters' optimization will be evenly distributed
+            across the devices, and the optimized parameters will then be broadcast to the other devices.
+            In some models, `Reduce` is faster. Default 'AllReduce'. )DOC")
      .def_property(
          "gradient_scale_strategy",
          [](const BuildStrategy &self) { return self.gradient_scale_; },
          [](BuildStrategy &self,
             BuildStrategy::GradientScaleStrategy strategy) {
            self.gradient_scale_ = strategy;
-          })
+          },
+          R"DOC(The type is STR. There are three ways of defining :math:`loss@grad` in
+            ParallelExecutor: 'CoeffNumDevice', 'One' and 'Customized'. By default,
+            ParallelExecutor sets :math:`loss@grad` according to the number of devices.
+            If you want to customize :math:`loss@grad`, you can choose 'Customized'.
+            Default 'CoeffNumDevice'.)DOC")
      .def_property(
          "debug_graphviz_path",
          [](const BuildStrategy &self) { return self.debug_graphviz_path_; },
          [](BuildStrategy &self, const std::string &path) {
            self.debug_graphviz_path_ = path;
-          })
+          },
+          R"DOC(The type is STR. debug_graphviz_path indicates the path for
+            writing the SSA Graph to a file in graphviz format.
+            It is useful for debugging. Default "")DOC")
      .def_property(
          "enable_data_balance",
          [](const BuildStrategy &self) { return self.enable_data_balance_; },
-          [](BuildStrategy &self, bool b) { self.enable_data_balance_ = b; })
-      .def_property("fuse_elewise_add_act_ops",
-                    [](const BuildStrategy &self) {
-                      return self.fuse_elewise_add_act_ops_;
-                    },
-                    [](BuildStrategy &self, bool b) {
-                      self.fuse_elewise_add_act_ops_ = b;
-                    });
+          [](BuildStrategy &self, bool b) {
+            self.enable_data_balance_ = b;
+          })  // FIXME(chengudo): enable_data_balance seems not important
+      .def_property(
+          "fuse_elewise_add_act_ops",
+          [](const BuildStrategy &self) {
+            return self.fuse_elewise_add_act_ops_;
+          },
+          [](BuildStrategy &self, bool b) {
+            self.fuse_elewise_add_act_ops_ = b;
+          },
+          R"DOC(The type is BOOL. fuse_elewise_add_act_ops indicates whether
+            to fuse elementwise_add_op and activation_op;
+            it may make the execution faster. Default False)DOC");
 
   pe.def(py::init<const std::vector<platform::Place> &,
                   const std::unordered_set<std::string> &,
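
The LoDTensor docstring added above describes LoD in length-based terms, while the public lod() API is offset-based. A small sketch of the x.lod = [[2, 3]] example from that docstring, assuming the 1.0-era fluid.LoDTensor pybind API (set / set_lod / lod):

import numpy as np
import paddle.fluid as fluid

place = fluid.CPUPlace()
t = fluid.LoDTensor()
# Two sequences of lengths 2 and 3 -> 5 rows of 2 columns each.
t.set(np.array([[1, 2], [3, 4],
                [5, 6], [7, 8], [9, 10]], dtype='float32'), place)
# The public lod() API is offset-based: lengths [[2, 3]] become offsets [[0, 2, 5]].
t.set_lod([[0, 2, 5]])

print(np.array(t).shape)  # (5, 2)
print(t.lod())            # [[0, 2, 5]]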

python/paddle/fluid/layers/io.py

Lines changed: 5 additions & 1 deletion
@@ -55,7 +55,11 @@ def data(name,
     Args:
        name(str): The name/alias of the function
        shape(list): Tuple declaring the shape.
-       append_batch_size(bool): Whether or not to append the data as a batch.
+       append_batch_size(bool):
+          1. If true, it prepends -1 to the shape.
+             For example, if shape=[1], the resulting shape is [-1, 1].
+          2. If shape already contains -1, such as shape=[1, -1],
+             append_batch_size will be enforced to be False (ineffective).
        dtype(int|float): The type of data : float32, float_16, int etc
        type(VarType): The output type. By default it is LOD_TENSOR.
        lod_level(int): The LoD Level. 0 means the input data is not a sequence.
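
A short sketch of the append_batch_size behaviour documented above (hypothetical values, assuming the 1.0-era fluid API):

import paddle.fluid as fluid

# Case 1: shape=[1] gets a batch dimension prepended -> [-1, 1].
x = fluid.layers.data(name='x', shape=[1], dtype='float32')
print(x.shape)  # (-1, 1)

# Case 2: the shape already contains -1, so append_batch_size is ignored.
y = fluid.layers.data(name='y', shape=[1, -1], dtype='float32',
                      append_batch_size=True)
print(y.shape)  # (1, -1)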

python/paddle/fluid/layers/ops.py

Lines changed: 12 additions & 5 deletions
@@ -14,6 +14,8 @@
 
 from __future__ import print_function
 from .layer_function_generator import generate_layer_fn, generate_layer_fn_noattr
+from .. import core
+from ..framework import convert_np_dtype_to_dtype_
 
 __activations_noattr__ = [
     'sigmoid',
@@ -58,8 +60,11 @@
 
 
 def uniform_random(shape, dtype=None, min=None, max=None, seed=None):
+    locals_var = locals().keys()
+    if not isinstance(dtype, core.VarDesc.VarType):
+        dtype = convert_np_dtype_to_dtype_(dtype)
     kwargs = dict()
-    for name in locals():
+    for name in locals_var:
         val = locals()[name]
         if val is not None:
             kwargs[name] = val
@@ -78,8 +83,9 @@ def uniform_random(shape, dtype=None, min=None, max=None, seed=None):
 
 
 def hard_shrink(x, threshold=None):
+    locals_var = locals().keys()
     kwargs = dict()
-    for name in locals():
+    for name in locals_var:
         val = locals()[name]
         if val is not None:
             kwargs[name] = val
@@ -99,12 +105,12 @@ def hard_shrink(x, threshold=None):
 
 
 def cumsum(x, axis=None, exclusive=None, reverse=None):
+    locals_var = locals().keys()
     kwargs = dict()
-    for name in locals():
+    for name in locals_var:
         val = locals()[name]
         if val is not None:
             kwargs[name] = val
-
     return _cum_sum_(**kwargs)
 
 
@@ -121,8 +127,9 @@ def cumsum(x, axis=None, exclusive=None, reverse=None):
 
 
 def thresholded_relu(x, threshold=None):
+    locals_var = locals().keys()
     kwargs = dict()
-    for name in locals():
+    for name in locals_var:
         val = locals()[name]
         if val is not None:
             kwargs[name] = val
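
Why these wrappers now snapshot the argument names first: the loop body itself creates new local names (name, val), so iterating the live locals() mapping can fail with "dictionary changed size during iteration" on CPython versions where locals() exposes the frame's mapping, and any helper variable created before the loop would otherwise leak into kwargs. A plain-Python sketch of the pattern (not Paddle code; the explicit list() makes the snapshot unambiguous):

def collect_non_none_args(shape, dtype=None, min=None, max=None, seed=None):
    # Snapshot the argument names before any helper locals exist.
    locals_var = list(locals().keys())
    kwargs = dict()
    for name in locals_var:
        val = locals()[name]  # safe: we iterate the snapshot, not locals() itself
        if val is not None:
            kwargs[name] = val
    return kwargs

print(collect_non_none_args([2, 3], dtype='float32', min=-1.0, max=1.0))
# -> {'shape': [2, 3], 'dtype': 'float32', 'min': -1.0, 'max': 1.0} (seed omitted)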

python/paddle/fluid/layers/tensor.py

Lines changed: 1 addition & 1 deletion
@@ -111,7 +111,7 @@ def create_global_var(shape,
                       force_cpu=False,
                       name=None):
     """
-    Create a new variable in the global block(block 0).
+    Create a new tensor variable with value in the global block(block 0).
 
     Args:
         shape(list[int]): shape of the variable
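
For context, a minimal usage sketch of the clarified docstring, assuming the 1.0-era signature create_global_var(shape, value, dtype, persistable=..., force_cpu=..., name=...): the call creates a tensor variable in the global block (block 0) and initializes it with value via the startup program.

import paddle.fluid as fluid

step = fluid.layers.create_global_var(shape=[1], value=0.0, dtype='float32',
                                      persistable=True, name='global_step')

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())   # fills the variable with 0.0
val, = exe.run(fluid.default_main_program(), fetch_list=[step])
print(val)  # [0.]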
