Skip to content

Commit fcadb45

Browse files
abhinavarorawangkuiyi
authored andcommitted
Separate VarType from VarDesc in framework.proto and fix all related compiler errors (#8414)
* Refine Type system * Fixing type inference * Fixed create_reader_op.cc * Fix var_desc.h * Fixed executor.cc * Fix shape_inference.h * Fixed create_reader_op.cc * Fix tensor_util.h * Fixed var_type_inference_test.cc * Fix shape_inference.cc * Fixed sum_op.c * Fixed read_op.cc * Fix var_type.h * Fixed beam_search_decode_op.cc * sendrecvop_utils.cc * Fix operator.cc * Fixed lookup_table_op.cc * Fixed op_desc.cc * Fixed get_places_op.cc * Fixed lod_rank_table_op.cc * Fixed beam_search_op.cc * Fix var_desc.cc * Fixed lod_tensor_to_array_op.cc * Fixed while_op.cc * Fix program_desc_test.cc * tensor_array_read_write_op.cc * Fix assign_op.cc * Fix executor.cc * Fix protobuf.cc * Fix protobuf.cc
1 parent f82fa64 commit fcadb45

26 files changed

+198
-184
lines changed

paddle/fluid/framework/executor.cc

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,24 +36,24 @@ namespace framework {
3636

3737
Executor::Executor(const platform::Place& place) : place_(place) {}
3838

39-
static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
40-
if (var_type == proto::VarDesc::LOD_TENSOR) {
39+
static void CreateTensor(Variable* var, proto::VarType::Type var_type) {
40+
if (var_type == proto::VarType::LOD_TENSOR) {
4141
var->GetMutable<LoDTensor>();
42-
} else if (var_type == proto::VarDesc::SELECTED_ROWS) {
42+
} else if (var_type == proto::VarType::SELECTED_ROWS) {
4343
var->GetMutable<SelectedRows>();
44-
} else if (var_type == proto::VarDesc::FEED_MINIBATCH) {
44+
} else if (var_type == proto::VarType::FEED_MINIBATCH) {
4545
var->GetMutable<FeedFetchList>();
46-
} else if (var_type == proto::VarDesc::FETCH_LIST) {
46+
} else if (var_type == proto::VarType::FETCH_LIST) {
4747
var->GetMutable<FeedFetchList>();
48-
} else if (var_type == proto::VarDesc::STEP_SCOPES) {
48+
} else if (var_type == proto::VarType::STEP_SCOPES) {
4949
var->GetMutable<std::vector<framework::Scope>>();
50-
} else if (var_type == proto::VarDesc::LOD_RANK_TABLE) {
50+
} else if (var_type == proto::VarType::LOD_RANK_TABLE) {
5151
var->GetMutable<LoDRankTable>();
52-
} else if (var_type == proto::VarDesc::LOD_TENSOR_ARRAY) {
52+
} else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) {
5353
var->GetMutable<LoDTensorArray>();
54-
} else if (var_type == proto::VarDesc::PLACE_LIST) {
54+
} else if (var_type == proto::VarType::PLACE_LIST) {
5555
var->GetMutable<platform::PlaceList>();
56-
} else if (var_type == proto::VarDesc::READER) {
56+
} else if (var_type == proto::VarType::READER) {
5757
var->GetMutable<ReaderHolder>();
5858
} else {
5959
PADDLE_THROW(
@@ -182,7 +182,7 @@ static bool has_feed_operators(
182182
auto var = block->FindVar(feed_holder_name);
183183
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
184184
feed_holder_name);
185-
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarDesc::FEED_MINIBATCH,
185+
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH,
186186
"'%s' variable should be 'FEED_MINIBATCH' type",
187187
feed_holder_name);
188188
}
@@ -222,7 +222,7 @@ static bool has_fetch_operators(
222222
auto var = block->FindVar(fetch_holder_name);
223223
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
224224
fetch_holder_name);
225-
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarDesc::FETCH_LIST,
225+
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST,
226226
"'%s' variable should be 'FETCH_LIST' type",
227227
fetch_holder_name);
228228
}
@@ -241,7 +241,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
241241
if (!has_feed_operators(global_block, feed_targets, feed_holder_name)) {
242242
// create feed_holder variable
243243
auto* feed_holder = global_block->Var(feed_holder_name);
244-
feed_holder->SetType(proto::VarDesc::FEED_MINIBATCH);
244+
feed_holder->SetType(proto::VarType::FEED_MINIBATCH);
245245
feed_holder->SetPersistable(true);
246246

247247
int i = 0;
@@ -274,7 +274,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
274274
if (!has_fetch_operators(global_block, fetch_targets, fetch_holder_name)) {
275275
// create fetch_holder variable
276276
auto* fetch_holder = global_block->Var(fetch_holder_name);
277-
fetch_holder->SetType(proto::VarDesc::FETCH_LIST);
277+
fetch_holder->SetType(proto::VarType::FETCH_LIST);
278278
fetch_holder->SetPersistable(true);
279279

280280
int i = 0;

paddle/fluid/framework/framework.proto

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -101,25 +101,8 @@ enum DataType {
101101
FP64 = 6;
102102
}
103103

104-
message TensorDesc {
105-
required DataType data_type = 1;
106-
repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480]
107-
}
108-
109-
message LoDTensorDesc {
110-
required TensorDesc tensor = 1;
111-
optional int32 lod_level = 2 [ default = 0 ];
112-
}
113-
114-
message LoDTensorArrayDesc {
115-
required TensorDesc tensor = 1;
116-
optional int32 lod_level = 2 [ default = 0 ];
117-
}
118-
119-
message ReaderDesc { repeated LoDTensorDesc lod_tensor = 1; }
120-
121-
message VarDesc {
122-
enum VarType {
104+
message VarType {
105+
enum Type {
123106
LOD_TENSOR = 1;
124107
SELECTED_ROWS = 2;
125108
FEED_MINIBATCH = 3;
@@ -130,13 +113,35 @@ message VarDesc {
130113
PLACE_LIST = 8;
131114
READER = 9;
132115
}
116+
117+
required Type type = 1;
118+
119+
message TensorDesc {
120+
required DataType data_type = 1;
121+
repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480]
122+
}
123+
optional TensorDesc selected_rows = 2;
124+
125+
message LoDTensorDesc {
126+
required TensorDesc tensor = 1;
127+
optional int32 lod_level = 2 [ default = 0 ];
128+
}
129+
optional LoDTensorDesc lod_tensor = 3;
130+
131+
message LoDTensorArrayDesc {
132+
required TensorDesc tensor = 1;
133+
optional int32 lod_level = 2 [ default = 0 ];
134+
}
135+
optional LoDTensorArrayDesc tensor_array = 4;
136+
137+
message ReaderDesc { repeated LoDTensorDesc lod_tensor = 1; }
138+
optional ReaderDesc reader = 5;
139+
}
140+
141+
message VarDesc {
133142
required string name = 1;
134143
required VarType type = 2;
135144
optional bool persistable = 3 [ default = false ];
136-
optional LoDTensorDesc lod_tensor = 4;
137-
optional TensorDesc selected_rows = 5;
138-
optional LoDTensorArrayDesc tensor_array = 6;
139-
optional ReaderDesc reader = 7;
140145
}
141146

142147
message BlockDesc {

paddle/fluid/framework/op_desc.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,11 @@ class CompileTimeInferShapeContext : public InferShapeContext {
5353
PADDLE_ENFORCE_LT(j, Outputs(out).size());
5454
auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
5555
auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
56-
if (in_var->GetType() != proto::VarDesc::LOD_TENSOR) {
56+
if (in_var->GetType() != proto::VarType::LOD_TENSOR) {
5757
VLOG(3) << "input " << in << " is not LodTensor";
5858
return;
5959
}
60-
PADDLE_ENFORCE_EQ(in_var->GetType(), proto::VarDesc::LOD_TENSOR,
60+
PADDLE_ENFORCE_EQ(in_var->GetType(), proto::VarType::LOD_TENSOR,
6161
"The %d-th output of Output(%s) must be LoDTensor.", j,
6262
out);
6363
out_var->SetLoDLevel(in_var->GetLoDLevel());
@@ -66,7 +66,7 @@ class CompileTimeInferShapeContext : public InferShapeContext {
6666
bool IsRuntime() const override;
6767

6868
protected:
69-
proto::VarDesc::VarType GetVarType(const std::string &name) const override;
69+
proto::VarType::Type GetVarType(const std::string &name) const override;
7070

7171
DDim GetDim(const std::string &name) const override;
7272

@@ -388,7 +388,7 @@ void OpDesc::InferVarType(BlockDesc *block) const {
388388
for (auto &out_pair : this->outputs_) {
389389
for (auto &out_var_name : out_pair.second) {
390390
block->FindRecursiveOrCreateVar(out_var_name)
391-
.SetType(proto::VarDesc::LOD_TENSOR);
391+
.SetType(proto::VarType::LOD_TENSOR);
392392
}
393393
}
394394
}
@@ -507,7 +507,7 @@ void CompileTimeInferShapeContext::SetRepeatedDims(
507507

508508
bool CompileTimeInferShapeContext::IsRuntime() const { return false; }
509509

510-
proto::VarDesc::VarType CompileTimeInferShapeContext::GetVarType(
510+
proto::VarType::Type CompileTimeInferShapeContext::GetVarType(
511511
const std::string &name) const {
512512
return block_.FindVarRecursive(name)->GetType();
513513
}

paddle/fluid/framework/operator.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
477477
}
478478
}
479479

480-
proto::VarDesc::VarType GetVarType(const std::string& name) const override {
480+
proto::VarType::Type GetVarType(const std::string& name) const override {
481481
auto* var = scope_.FindVar(name);
482482
return ToVarType(var->Type());
483483
}

paddle/fluid/framework/program_desc_test.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@ TEST(ProgramDesc, copy_ctor) {
2222
ProgramDesc program;
2323
auto* global_block = program.MutableBlock(0);
2424
auto* x = global_block->Var("X");
25-
x->SetType(proto::VarDesc_VarType_LOD_TENSOR);
25+
x->SetType(proto::VarType::LOD_TENSOR);
2626
x->SetLoDLevel(0);
2727
x->SetDataType(proto::FP32);
2828
x->SetShape({1000, 784});
2929

3030
auto* y = global_block->Var("Y");
31-
y->SetType(proto::VarDesc_VarType_LOD_TENSOR);
31+
y->SetType(proto::VarType::LOD_TENSOR);
3232
y->SetLoDLevel(0);
3333
y->SetDataType(proto::FP32);
3434
y->SetShape({784, 100});
@@ -39,7 +39,7 @@ TEST(ProgramDesc, copy_ctor) {
3939
op->SetInput("Y", {y->Name()});
4040

4141
auto* out = global_block->Var("Out");
42-
out->SetType(proto::VarDesc_VarType_LOD_TENSOR);
42+
out->SetType(proto::VarType::LOD_TENSOR);
4343
op->SetOutput("Y", {out->Name()});
4444

4545
ProgramDesc program_copy(program);
@@ -84,13 +84,13 @@ TEST(ProgramDescBind, serialize_and_deserialize) {
8484
ProgramDesc program_origin;
8585
auto* global_block = program_origin.MutableBlock(0);
8686
auto* x = global_block->Var("X");
87-
x->SetType(proto::VarDesc_VarType_LOD_TENSOR);
87+
x->SetType(proto::VarType::LOD_TENSOR);
8888
x->SetLoDLevel(0);
8989
x->SetDataType(proto::FP32);
9090
x->SetShape({1000, 784});
9191

9292
auto* y = global_block->Var("Y");
93-
y->SetType(proto::VarDesc_VarType_LOD_TENSOR);
93+
y->SetType(proto::VarType::LOD_TENSOR);
9494
y->SetLoDLevel(0);
9595
y->SetDataType(proto::FP32);
9696
y->SetShape({784, 100});
@@ -101,7 +101,7 @@ TEST(ProgramDescBind, serialize_and_deserialize) {
101101
op->SetInput("Y", {y->Name()});
102102

103103
auto* out = global_block->Var("Out");
104-
out->SetType(proto::VarDesc_VarType_LOD_TENSOR);
104+
out->SetType(proto::VarType::LOD_TENSOR);
105105
op->SetOutput("Y", {out->Name()});
106106

107107
std::string binary_str;

paddle/fluid/framework/shape_inference.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,19 +116,19 @@ void InferShapeContext::SetDims(const std::vector<std::string> &names,
116116
}
117117
}
118118

119-
std::vector<proto::VarDesc::VarType> InferShapeContext::GetInputsVarType(
119+
std::vector<proto::VarType::Type> InferShapeContext::GetInputsVarType(
120120
const std::string &name) const {
121121
return GetVarTypes(Inputs(name));
122122
}
123123

124-
std::vector<proto::VarDesc::VarType> InferShapeContext::GetOutputsVarType(
124+
std::vector<proto::VarType::Type> InferShapeContext::GetOutputsVarType(
125125
const std::string &name) const {
126126
return GetVarTypes(Outputs(name));
127127
}
128128

129-
std::vector<proto::VarDesc::VarType> InferShapeContext::GetVarTypes(
129+
std::vector<proto::VarType::Type> InferShapeContext::GetVarTypes(
130130
const std::vector<std::string> &names) const {
131-
std::vector<proto::VarDesc::VarType> retv;
131+
std::vector<proto::VarType::Type> retv;
132132
retv.resize(names.size());
133133
std::transform(names.begin(), names.end(), retv.begin(),
134134
std::bind(std::mem_fn(&InferShapeContext::GetVarType), this,

paddle/fluid/framework/shape_inference.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ class InferShapeContext {
3131
virtual bool HasInput(const std::string &name) const = 0;
3232
virtual bool HasOutput(const std::string &name) const = 0;
3333

34-
std::vector<proto::VarDesc::VarType> GetInputsVarType(
34+
std::vector<proto::VarType::Type> GetInputsVarType(
3535
const std::string &name) const;
36-
std::vector<proto::VarDesc::VarType> GetOutputsVarType(
36+
std::vector<proto::VarType::Type> GetOutputsVarType(
3737
const std::string &name) const;
3838

3939
virtual bool HasInputs(const std::string &name) const = 0;
@@ -75,10 +75,10 @@ class InferShapeContext {
7575

7676
std::vector<DDim> GetDims(const std::vector<std::string> &names) const;
7777

78-
std::vector<proto::VarDesc::VarType> GetVarTypes(
78+
std::vector<proto::VarType::Type> GetVarTypes(
7979
const std::vector<std::string> &names) const;
8080

81-
virtual proto::VarDesc::VarType GetVarType(const std::string &name) const = 0;
81+
virtual proto::VarType::Type GetVarType(const std::string &name) const = 0;
8282

8383
virtual InferShapeVarPtr GetVarPtr(const std::string &name) = 0;
8484
};

paddle/fluid/framework/tensor_util.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ inline void SerializeToStream(std::ostream& os, const Tensor& tensor,
225225
{ // the 2nd field, tensor description
226226
// int32_t size
227227
// void* protobuf message
228-
proto::TensorDesc desc;
228+
proto::VarType::TensorDesc desc;
229229
desc.set_data_type(framework::ToDataType(tensor.type()));
230230
auto dims = framework::vectorize(tensor.dims());
231231
auto* pb_dims = desc.mutable_dims();
@@ -290,7 +290,7 @@ inline void DeserializeFromStream(std::istream& is, Tensor* tensor,
290290
uint32_t version;
291291
is.read(reinterpret_cast<char*>(&version), sizeof(version));
292292
PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
293-
proto::TensorDesc desc;
293+
proto::VarType::TensorDesc desc;
294294
{ // int32_t size
295295
// proto buffer
296296
int32_t size;

0 commit comments

Comments
 (0)