Skip to content

Commit c7f1f3e

Browse files
authored
Merge pull request #16214 from velconia/imperative_infer_var_type
Implement imperative infer var type
2 parents f8df9eb + 565b19b commit c7f1f3e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+575
-345
lines changed

paddle/fluid/framework/details/graph_test_base.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,11 @@ class SplitOpMaker : public OpProtoAndCheckerMaker {
6868

6969
class DummyVarTypeInference : public VarTypeInference {
7070
public:
71-
void operator()(const OpDesc& op_desc, BlockDesc* block) const override {
72-
auto& inputs = op_desc.Input("X");
73-
auto type = block->Var(inputs.front())->GetType();
74-
auto out_var_name = op_desc.Output("Out").front();
75-
block->Var(out_var_name)->SetType(type);
71+
void operator()(framework::InferVarTypeContext* ctx) const override {
72+
auto& inputs = ctx->Input("X");
73+
auto type = ctx->GetType(inputs.front());
74+
auto out_var_name = ctx->Output("Out").front();
75+
ctx->SetType(out_var_name, type);
7676
}
7777
};
7878

paddle/fluid/framework/details/op_registry.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ limitations under the License. */
1616

1717
#include <string>
1818
#include <tuple>
19+
#include <unordered_map>
20+
#include <unordered_set>
1921
#include <vector>
2022
#include "paddle/fluid/framework/grad_op_desc_maker.h"
2123
#include "paddle/fluid/framework/inplace_op_inference.h"
@@ -127,9 +129,9 @@ struct OpInfoFiller<T, kGradOpDescMaker> {
127129
template <typename T>
128130
struct OpInfoFiller<T, kVarTypeInference> {
129131
void operator()(const char* op_type, OpInfo* info) const {
130-
info->infer_var_type_ = [](const OpDesc& fwd_op, BlockDesc* block) {
132+
info->infer_var_type_ = [](InferVarTypeContext* context) {
131133
T inference;
132-
inference(fwd_op, block);
134+
inference(context);
133135
};
134136
}
135137
};

paddle/fluid/framework/ir/graph_test.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,20 +43,20 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
4343

4444
class SumOpVarTypeInference : public VarTypeInference {
4545
public:
46-
void operator()(const OpDesc &op_desc, BlockDesc *block) const override {
47-
auto &inputs = op_desc.Input("X");
46+
void operator()(InferVarTypeContext *ctx) const override {
47+
auto &inputs = ctx->Input("X");
4848
auto default_var_type = proto::VarType::SELECTED_ROWS;
4949

5050
bool any_input_is_lod_tensor = std::any_of(
51-
inputs.begin(), inputs.end(), [block](const std::string &name) {
52-
return block->Var(name)->GetType() == proto::VarType::LOD_TENSOR;
51+
inputs.begin(), inputs.end(), [&ctx](const std::string &name) {
52+
return ctx->GetType(name) == proto::VarType::LOD_TENSOR;
5353
});
5454
if (any_input_is_lod_tensor) {
5555
default_var_type = proto::VarType::LOD_TENSOR;
5656
}
5757

58-
auto out_var_name = op_desc.Output("Out").front();
59-
block->Var(out_var_name)->SetType(default_var_type);
58+
auto out_var_name = ctx->Output("Out").front();
59+
ctx->SetType(out_var_name, default_var_type);
6060
}
6161
};
6262

@@ -71,7 +71,7 @@ class DummyOpMaker : public OpProtoAndCheckerMaker {
7171

7272
class DummyOpVarTypeInference : public VarTypeInference {
7373
public:
74-
void operator()(const OpDesc &op_desc, BlockDesc *block) const override {}
74+
void operator()(framework::InferVarTypeContext *ctx) const override {}
7575
};
7676
} // namespace framework
7777
} // namespace paddle

paddle/fluid/framework/op_desc.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ limitations under the License. */
2424
#include "paddle/fluid/framework/operator.h"
2525
#include "paddle/fluid/framework/program_desc.h"
2626
#include "paddle/fluid/framework/shape_inference.h"
27+
#include "paddle/fluid/framework/var_type_inference.h"
2728

2829
namespace paddle {
2930
namespace framework {
@@ -677,7 +678,8 @@ void OpDesc::InferVarType(BlockDesc *block) const {
677678
// var type inference. Hence, we don't do any "default" setting here.
678679
auto &info = OpInfoMap::Instance().Get(this->Type());
679680
if (info.infer_var_type_) {
680-
info.infer_var_type_(*this, block);
681+
InferVarTypeContext context(this, block);
682+
info.infer_var_type_(&context);
681683
}
682684
}
683685

paddle/fluid/framework/type_defs.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ namespace framework {
2727
class OperatorBase;
2828
class OpDesc;
2929
class InferShapeContext;
30+
class InferVarTypeContext;
3031
class BlockDesc;
3132
class Variable;
3233

@@ -53,7 +54,7 @@ using GradOpMakerFN = std::function<std::vector<std::unique_ptr<OpDesc>>(
5354
const std::vector<BlockDesc*>& grad_block)>;
5455

5556
using InferVarTypeFN =
56-
std::function<void(const OpDesc& /*op_desc*/, BlockDesc* /*block*/)>;
57+
std::function<void(framework::InferVarTypeContext* /*context*/)>;
5758

5859
using InferShapeFN = std::function<void(InferShapeContext*)>;
5960

paddle/fluid/framework/var_type_inference.h

Lines changed: 108 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,33 +14,132 @@ limitations under the License. */
1414

1515
#pragma once
1616
#include <string>
17+
#include <unordered_map>
18+
#include <vector>
1719
#include "paddle/fluid/framework/block_desc.h"
1820
#include "paddle/fluid/framework/op_desc.h"
1921
#include "paddle/fluid/framework/type_defs.h"
2022

2123
namespace paddle {
2224
namespace framework {
2325

26+
class OpDesc;
27+
class BlockDesc;
28+
// default infer var type context
29+
class InferVarTypeContext {
30+
public:
31+
InferVarTypeContext(const OpDesc* op, BlockDesc* block)
32+
: op_(op), block_(block) {}
33+
34+
virtual ~InferVarTypeContext() {}
35+
36+
virtual Attribute GetAttr(const std::string& name) const {
37+
PADDLE_ENFORCE_NOT_NULL(op_);
38+
return op_->GetAttr(name);
39+
}
40+
41+
virtual bool HasVar(const std::string& name) const {
42+
PADDLE_ENFORCE_NOT_NULL(block_);
43+
return block_->FindVarRecursive(name) != nullptr;
44+
}
45+
46+
virtual bool HasInput(const std::string& name) const {
47+
PADDLE_ENFORCE_NOT_NULL(op_);
48+
return op_->Inputs().count(name) > 0;
49+
}
50+
51+
virtual bool HasOutput(const std::string& name) const {
52+
PADDLE_ENFORCE_NOT_NULL(op_);
53+
return op_->Outputs().count(name) > 0;
54+
}
55+
56+
virtual const std::vector<std::string>& Input(const std::string& name) const {
57+
PADDLE_ENFORCE_NOT_NULL(op_);
58+
return op_->Input(name);
59+
}
60+
61+
virtual const std::vector<std::string>& Output(
62+
const std::string& name) const {
63+
PADDLE_ENFORCE_NOT_NULL(op_);
64+
return op_->Output(name);
65+
}
66+
67+
virtual proto::VarType::Type GetType(const std::string& name) const {
68+
PADDLE_ENFORCE_NOT_NULL(block_);
69+
return block_->FindRecursiveOrCreateVar(name).GetType();
70+
}
71+
72+
virtual void SetType(const std::string& name, proto::VarType::Type type) {
73+
PADDLE_ENFORCE_NOT_NULL(block_);
74+
block_->FindRecursiveOrCreateVar(name).SetType(type);
75+
}
76+
77+
virtual proto::VarType::Type GetDataType(const std::string& name) const {
78+
PADDLE_ENFORCE_NOT_NULL(block_);
79+
return block_->FindRecursiveOrCreateVar(name).GetDataType();
80+
}
81+
82+
virtual void SetDataType(const std::string& name, proto::VarType::Type type) {
83+
PADDLE_ENFORCE_NOT_NULL(block_);
84+
block_->FindRecursiveOrCreateVar(name).SetDataType(type);
85+
}
86+
87+
virtual std::vector<proto::VarType::Type> GetDataTypes(
88+
const std::string& name) const {
89+
PADDLE_ENFORCE_NOT_NULL(block_);
90+
return block_->FindRecursiveOrCreateVar(name).GetDataTypes();
91+
}
92+
93+
virtual void SetDataTypes(
94+
const std::string& name,
95+
const std::vector<proto::VarType::Type>& multiple_data_type) {
96+
PADDLE_ENFORCE_NOT_NULL(block_);
97+
block_->FindRecursiveOrCreateVar(name).SetDataTypes(multiple_data_type);
98+
}
99+
100+
virtual std::vector<int64_t> GetShape(const std::string& name) const {
101+
PADDLE_ENFORCE_NOT_NULL(block_);
102+
return block_->FindRecursiveOrCreateVar(name).GetShape();
103+
}
104+
105+
virtual void SetShape(const std::string& name,
106+
const std::vector<int64_t>& dims) {
107+
PADDLE_ENFORCE_NOT_NULL(block_);
108+
block_->FindRecursiveOrCreateVar(name).SetShape(dims);
109+
}
110+
111+
virtual int32_t GetLoDLevel(const std::string& name) const {
112+
PADDLE_ENFORCE_NOT_NULL(block_);
113+
return block_->FindRecursiveOrCreateVar(name).GetLoDLevel();
114+
}
115+
116+
virtual void SetLoDLevel(const std::string& name, int32_t lod_level) {
117+
PADDLE_ENFORCE_NOT_NULL(block_);
118+
block_->FindRecursiveOrCreateVar(name).SetLoDLevel(lod_level);
119+
}
120+
121+
protected:
122+
const OpDesc* op_;
123+
BlockDesc* block_;
124+
};
125+
24126
class VarTypeInference {
25127
public:
26128
virtual ~VarTypeInference() {}
27-
virtual void operator()(const OpDesc& op_desc, BlockDesc* block) const = 0;
129+
virtual void operator()(InferVarTypeContext* context) const = 0; // NOLINT
28130
};
29131

30132
class PassInDtypeAndVarTypeToOutput : public framework::VarTypeInference {
31133
public:
32-
void operator()(const framework::OpDesc& op_desc,
33-
framework::BlockDesc* block) const final {
134+
void operator()(framework::InferVarTypeContext* ctx) const final { // NOLINT
34135
auto in_out_var_names = this->GetInputOutputWithSameType();
35136

36137
for (auto& i_o_n : in_out_var_names) {
37-
auto& x_name = op_desc.Input(i_o_n.first).at(0);
38-
auto& out_name = op_desc.Output(i_o_n.second).at(0);
138+
auto& x_name = ctx->Input(i_o_n.first).at(0);
139+
auto& out_name = ctx->Output(i_o_n.second).at(0);
39140

40-
auto& x = block->FindRecursiveOrCreateVar(x_name);
41-
auto& out = block->FindRecursiveOrCreateVar(out_name);
42-
out.SetType(x.GetType());
43-
out.SetDataType(x.GetDataType());
141+
ctx->SetType(out_name, ctx->GetType(x_name));
142+
ctx->SetDataType(out_name, ctx->GetDataType(x_name));
44143
}
45144
}
46145

paddle/fluid/framework/var_type_inference_test.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,20 +44,20 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
4444

4545
class SumOpVarTypeInference : public VarTypeInference {
4646
public:
47-
void operator()(const OpDesc &op_desc, BlockDesc *block) const override {
48-
auto &inputs = op_desc.Input("X");
47+
void operator()(framework::InferVarTypeContext *ctx) const override {
48+
auto &inputs = ctx->Input("X");
4949
auto default_var_type = proto::VarType::SELECTED_ROWS;
5050

5151
bool any_input_is_lod_tensor = std::any_of(
52-
inputs.begin(), inputs.end(), [block](const std::string &name) {
53-
return block->Var(name)->GetType() == proto::VarType::LOD_TENSOR;
52+
inputs.begin(), inputs.end(), [&ctx](const std::string &name) {
53+
return ctx->GetType(name) == proto::VarType::LOD_TENSOR;
5454
});
5555
if (any_input_is_lod_tensor) {
5656
default_var_type = proto::VarType::LOD_TENSOR;
5757
}
5858

59-
auto out_var_name = op_desc.Output("Out").front();
60-
block->Var(out_var_name)->SetType(default_var_type);
59+
auto out_var_name = ctx->Output("Out").front();
60+
ctx->SetType(out_var_name, default_var_type);
6161
}
6262
};
6363
} // namespace framework

0 commit comments

Comments
 (0)