@@ -23,14 +23,14 @@ namespace operators {
 template <typename T>
 class CPUUniformRandomKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    framework::Tensor* tensor = nullptr;
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    framework::Tensor *tensor = nullptr;
     auto out_var = ctx.OutputVar("Out");
     if (out_var->IsType<framework::LoDTensor>()) {
       tensor = out_var->GetMutable<framework::LoDTensor>();
     } else if (out_var->IsType<framework::SelectedRows>()) {
       auto shape = ctx.Attr<std::vector<int>>("shape");
-      auto* selected_rows = out_var->GetMutable<framework::SelectedRows>();
+      auto *selected_rows = out_var->GetMutable<framework::SelectedRows>();
       tensor = selected_rows->mutable_value();
       tensor->Resize(framework::make_ddim(shape));
       selected_rows->mutable_rows()->reserve(shape[0]);
@@ -39,7 +39,7 @@ class CPUUniformRandomKernel : public framework::OpKernel<T> {
           "uniform_random_op's output only"
           "supports SelectedRows and LoDTensor");
     }
-    T* data = tensor->mutable_data<T>(ctx.GetPlace());
+    T *data = tensor->mutable_data<T>(ctx.GetPlace());
     unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
     std::minstd_rand engine;
     if (seed == 0) {
@@ -60,14 +60,14 @@ class UniformRandomOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;

-  void InferShape(framework::InferShapeContext* ctx) const override {
+  void InferShape(framework::InferShapeContext *ctx) const override {
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of UniformRandomOp should not be null.");

     PADDLE_ENFORCE(
         ctx->Attrs().Get<float>("min") < ctx->Attrs().Get<float>("max"),
         "uniform_random's min must less then max");
-    auto& shape = ctx->Attrs().Get<std::vector<int>>("shape");
+    auto &shape = ctx->Attrs().Get<std::vector<int>>("shape");
     std::vector<int64_t> temp;
     temp.reserve(shape.size());
     for (auto dim : shape) {
@@ -78,7 +78,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {

  protected:
   framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override {
+      const framework::ExecutionContext &ctx) const override {
     return framework::OpKernelType(
         static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
         ctx.GetPlace());
@@ -112,17 +112,17 @@ uniform distribution. The random result is in set [min, max].

 class UniformRandomOpVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override {
+  void operator()(const framework::OpDesc &op_desc,
+                  framework::BlockDesc *block) const override {
     auto out_var_name = op_desc.Output("Out").front();
-    if (block->FindRecursiveOrCreateVar(out_var_name).GetType() ==
-        framework::proto::VarType::SELECTED_ROWS) {
-      block->FindRecursiveOrCreateVar(out_var_name)
-          .SetType(framework::proto::VarType::SELECTED_ROWS);
-    } else {
-      block->FindRecursiveOrCreateVar(out_var_name)
-          .SetType(framework::proto::VarType::LOD_TENSOR);
+    auto var_data_type = static_cast<framework::proto::VarType::Type>(
+        boost::get<int>(op_desc.GetAttr("dtype")));
+
+    auto out_var = block->FindRecursiveOrCreateVar(out_var_name);
+    if (out_var.GetType() != framework::proto::VarType::SELECTED_ROWS) {
+      out_var.SetType(framework::proto::VarType::LOD_TENSOR);
     }
+    out_var.SetDataType(var_data_type);
   }
 };
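
The substantive change in this diff is in UniformRandomOpVarTypeInference: besides choosing between LOD_TENSOR and SELECTED_ROWS, operator() now also reads the op's "dtype" attribute and stamps it onto the output variable via SetDataType. Below is a minimal, self-contained sketch of that inference rule under simplified assumptions; VarType, DataType, VarDesc, BlockMock, and InferUniformRandomOutVar are stand-ins invented here for illustration and are not Paddle's real framework classes.

    // Minimal sketch (not Paddle's API): keep SELECTED_ROWS outputs as-is,
    // default everything else to LOD_TENSOR, and always propagate the
    // "dtype" attribute to the output variable.
    #include <cassert>
    #include <map>
    #include <string>

    enum class VarType { LOD_TENSOR, SELECTED_ROWS };
    enum class DataType { FP32, FP64, INT32 };

    struct VarDesc {
      VarType type = VarType::LOD_TENSOR;
      DataType data_type = DataType::FP32;
    };

    // Hypothetical stand-in for BlockDesc::FindRecursiveOrCreateVar.
    struct BlockMock {
      std::map<std::string, VarDesc> vars;
      VarDesc &FindRecursiveOrCreateVar(const std::string &name) {
        return vars[name];  // default-constructs the var on first use
      }
    };

    // Simplified form of the rule implemented by the + lines of the last hunk.
    void InferUniformRandomOutVar(BlockMock *block, const std::string &out_name,
                                  DataType dtype_attr) {
      VarDesc &out_var = block->FindRecursiveOrCreateVar(out_name);
      if (out_var.type != VarType::SELECTED_ROWS) {
        out_var.type = VarType::LOD_TENSOR;
      }
      out_var.data_type = dtype_attr;  // new behaviour: dtype is now set
    }

    int main() {
      BlockMock block;
      InferUniformRandomOutVar(&block, "Out", DataType::FP64);
      assert(block.vars["Out"].type == VarType::LOD_TENSOR);
      assert(block.vars["Out"].data_type == DataType::FP64);
      return 0;
    }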