Skip to content

Commit 2c51410

Browse files
author
chengduo
authored
fix layers.uniform_random (#13859)
test=release/1.0.0
1 parent cddff20 commit 2c51410

File tree

2 files changed

+28
-21
lines changed

2 files changed

+28
-21
lines changed

paddle/fluid/operators/uniform_random_op.cc

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,14 @@ namespace operators {
2323
template <typename T>
2424
class CPUUniformRandomKernel : public framework::OpKernel<T> {
2525
public:
26-
void Compute(const framework::ExecutionContext& ctx) const override {
27-
framework::Tensor* tensor = nullptr;
26+
void Compute(const framework::ExecutionContext &ctx) const override {
27+
framework::Tensor *tensor = nullptr;
2828
auto out_var = ctx.OutputVar("Out");
2929
if (out_var->IsType<framework::LoDTensor>()) {
3030
tensor = out_var->GetMutable<framework::LoDTensor>();
3131
} else if (out_var->IsType<framework::SelectedRows>()) {
3232
auto shape = ctx.Attr<std::vector<int>>("shape");
33-
auto* selected_rows = out_var->GetMutable<framework::SelectedRows>();
33+
auto *selected_rows = out_var->GetMutable<framework::SelectedRows>();
3434
tensor = selected_rows->mutable_value();
3535
tensor->Resize(framework::make_ddim(shape));
3636
selected_rows->mutable_rows()->reserve(shape[0]);
@@ -39,7 +39,7 @@ class CPUUniformRandomKernel : public framework::OpKernel<T> {
3939
"uniform_random_op's output only"
4040
"supports SelectedRows and LoDTensor");
4141
}
42-
T* data = tensor->mutable_data<T>(ctx.GetPlace());
42+
T *data = tensor->mutable_data<T>(ctx.GetPlace());
4343
unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
4444
std::minstd_rand engine;
4545
if (seed == 0) {
@@ -60,14 +60,14 @@ class UniformRandomOp : public framework::OperatorWithKernel {
6060
public:
6161
using framework::OperatorWithKernel::OperatorWithKernel;
6262

63-
void InferShape(framework::InferShapeContext* ctx) const override {
63+
void InferShape(framework::InferShapeContext *ctx) const override {
6464
PADDLE_ENFORCE(ctx->HasOutput("Out"),
6565
"Output(Out) of UniformRandomOp should not be null.");
6666

6767
PADDLE_ENFORCE(
6868
ctx->Attrs().Get<float>("min") < ctx->Attrs().Get<float>("max"),
6969
"uniform_random's min must less then max");
70-
auto& shape = ctx->Attrs().Get<std::vector<int>>("shape");
70+
auto &shape = ctx->Attrs().Get<std::vector<int>>("shape");
7171
std::vector<int64_t> temp;
7272
temp.reserve(shape.size());
7373
for (auto dim : shape) {
@@ -78,7 +78,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {
7878

7979
protected:
8080
framework::OpKernelType GetExpectedKernelType(
81-
const framework::ExecutionContext& ctx) const override {
81+
const framework::ExecutionContext &ctx) const override {
8282
return framework::OpKernelType(
8383
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
8484
ctx.GetPlace());
@@ -112,17 +112,17 @@ uniform distribution. The random result is in set [min, max].
112112

113113
class UniformRandomOpVarTypeInference : public framework::VarTypeInference {
114114
public:
115-
void operator()(const framework::OpDesc& op_desc,
116-
framework::BlockDesc* block) const override {
115+
void operator()(const framework::OpDesc &op_desc,
116+
framework::BlockDesc *block) const override {
117117
auto out_var_name = op_desc.Output("Out").front();
118-
if (block->FindRecursiveOrCreateVar(out_var_name).GetType() ==
119-
framework::proto::VarType::SELECTED_ROWS) {
120-
block->FindRecursiveOrCreateVar(out_var_name)
121-
.SetType(framework::proto::VarType::SELECTED_ROWS);
122-
} else {
123-
block->FindRecursiveOrCreateVar(out_var_name)
124-
.SetType(framework::proto::VarType::LOD_TENSOR);
118+
auto var_data_type = static_cast<framework::proto::VarType::Type>(
119+
boost::get<int>(op_desc.GetAttr("dtype")));
120+
121+
auto out_var = block->FindRecursiveOrCreateVar(out_var_name);
122+
if (out_var.GetType() != framework::proto::VarType::SELECTED_ROWS) {
123+
out_var.SetType(framework::proto::VarType::LOD_TENSOR);
125124
}
125+
out_var.SetDataType(var_data_type);
126126
}
127127
};
128128

python/paddle/fluid/layers/ops.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
from __future__ import print_function
1616
from .layer_function_generator import generate_layer_fn, generate_layer_fn_noattr
17+
from .. import core
18+
from ..framework import convert_np_dtype_to_dtype_
1719

1820
__activations_noattr__ = [
1921
'sigmoid',
@@ -58,8 +60,11 @@
5860

5961

6062
def uniform_random(shape, dtype=None, min=None, max=None, seed=None):
63+
locals_var = locals().keys()
64+
if not isinstance(dtype, core.VarDesc.VarType):
65+
dtype = convert_np_dtype_to_dtype_(dtype)
6166
kwargs = dict()
62-
for name in locals():
67+
for name in locals_var:
6368
val = locals()[name]
6469
if val is not None:
6570
kwargs[name] = val
@@ -78,8 +83,9 @@ def uniform_random(shape, dtype=None, min=None, max=None, seed=None):
7883

7984

8085
def hard_shrink(x, threshold=None):
86+
locals_var = locals().keys()
8187
kwargs = dict()
82-
for name in locals():
88+
for name in locals_var:
8389
val = locals()[name]
8490
if val is not None:
8591
kwargs[name] = val
@@ -99,12 +105,12 @@ def hard_shrink(x, threshold=None):
99105

100106

101107
def cumsum(x, axis=None, exclusive=None, reverse=None):
108+
locals_var = locals().keys()
102109
kwargs = dict()
103-
for name in locals():
110+
for name in locals_var:
104111
val = locals()[name]
105112
if val is not None:
106113
kwargs[name] = val
107-
108114
return _cum_sum_(**kwargs)
109115

110116

@@ -121,8 +127,9 @@ def cumsum(x, axis=None, exclusive=None, reverse=None):
121127

122128

123129
def thresholded_relu(x, threshold=None):
130+
locals_var = locals().keys()
124131
kwargs = dict()
125-
for name in locals():
132+
for name in locals_var:
126133
val = locals()[name]
127134
if val is not None:
128135
kwargs[name] = val

0 commit comments

Comments
 (0)