Skip to content

Commit b40e41f

Browse files
committed
Polish code style
test=develop
1 parent 36dce65 commit b40e41f

40 files changed

+192
-192
lines changed

paddle/fluid/framework/details/graph_test_base.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,11 @@ class SplitOpMaker : public OpProtoAndCheckerMaker {
6868

6969
class DummyVarTypeInference : public VarTypeInference {
7070
public:
71-
void operator()(framework::InferVarTypeContext& ctx) const override {
72-
auto& inputs = ctx.Input("X");
73-
auto type = ctx.GetType(inputs.front());
74-
auto out_var_name = ctx.Output("Out").front();
75-
ctx.SetType(out_var_name, type);
71+
void operator()(framework::InferVarTypeContext* ctx) const override {
72+
auto& inputs = ctx->Input("X");
73+
auto type = ctx->GetType(inputs.front());
74+
auto out_var_name = ctx->Output("Out").front();
75+
ctx->SetType(out_var_name, type);
7676
}
7777
};
7878

paddle/fluid/framework/details/op_registry.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ struct OpInfoFiller<T, kVarTypeInference> {
131131
void operator()(const char* op_type, OpInfo* info) const {
132132
info->infer_var_type_ = [](InferVarTypeContext* context) {
133133
T inference;
134-
inference(*context);
134+
inference(context);
135135
};
136136
}
137137
};

paddle/fluid/framework/ir/graph_test.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,20 +43,20 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
4343

4444
class SumOpVarTypeInference : public VarTypeInference {
4545
public:
46-
void operator()(InferVarTypeContext &ctx) const override {
47-
auto &inputs = ctx.Input("X");
46+
void operator()(InferVarTypeContext *ctx) const override {
47+
auto &inputs = ctx->Input("X");
4848
auto default_var_type = proto::VarType::SELECTED_ROWS;
4949

5050
bool any_input_is_lod_tensor = std::any_of(
5151
inputs.begin(), inputs.end(), [&ctx](const std::string &name) {
52-
return ctx.GetType(name) == proto::VarType::LOD_TENSOR;
52+
return ctx->GetType(name) == proto::VarType::LOD_TENSOR;
5353
});
5454
if (any_input_is_lod_tensor) {
5555
default_var_type = proto::VarType::LOD_TENSOR;
5656
}
5757

58-
auto out_var_name = ctx.Output("Out").front();
59-
ctx.SetType(out_var_name, default_var_type);
58+
auto out_var_name = ctx->Output("Out").front();
59+
ctx->SetType(out_var_name, default_var_type);
6060
}
6161
};
6262

@@ -71,7 +71,7 @@ class DummyOpMaker : public OpProtoAndCheckerMaker {
7171

7272
class DummyOpVarTypeInference : public VarTypeInference {
7373
public:
74-
void operator()(framework::InferVarTypeContext &ctx) const override {}
74+
void operator()(framework::InferVarTypeContext *ctx) const override {}
7575
};
7676
} // namespace framework
7777
} // namespace paddle

paddle/fluid/framework/var_type_inference.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -126,20 +126,20 @@ class InferVarTypeContext {
126126
class VarTypeInference {
127127
public:
128128
virtual ~VarTypeInference() {}
129-
virtual void operator()(InferVarTypeContext& context) const = 0; // NOLINT
129+
virtual void operator()(InferVarTypeContext* context) const = 0; // NOLINT
130130
};
131131

132132
class PassInDtypeAndVarTypeToOutput : public framework::VarTypeInference {
133133
public:
134-
void operator()(framework::InferVarTypeContext& ctx) const final { // NOLINT
134+
void operator()(framework::InferVarTypeContext* ctx) const final { // NOLINT
135135
auto in_out_var_names = this->GetInputOutputWithSameType();
136136

137137
for (auto& i_o_n : in_out_var_names) {
138-
auto& x_name = ctx.Input(i_o_n.first).at(0);
139-
auto& out_name = ctx.Output(i_o_n.second).at(0);
138+
auto& x_name = ctx->Input(i_o_n.first).at(0);
139+
auto& out_name = ctx->Output(i_o_n.second).at(0);
140140

141-
ctx.SetType(out_name, ctx.GetType(x_name));
142-
ctx.SetDataType(out_name, ctx.GetDataType(x_name));
141+
ctx->SetType(out_name, ctx->GetType(x_name));
142+
ctx->SetDataType(out_name, ctx->GetDataType(x_name));
143143
}
144144
}
145145

paddle/fluid/framework/var_type_inference_test.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,20 +44,20 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
4444

4545
class SumOpVarTypeInference : public VarTypeInference {
4646
public:
47-
void operator()(framework::InferVarTypeContext &ctx) const override {
48-
auto &inputs = ctx.Input("X");
47+
void operator()(framework::InferVarTypeContext *ctx) const override {
48+
auto &inputs = ctx->Input("X");
4949
auto default_var_type = proto::VarType::SELECTED_ROWS;
5050

5151
bool any_input_is_lod_tensor = std::any_of(
5252
inputs.begin(), inputs.end(), [&ctx](const std::string &name) {
53-
return ctx.GetType(name) == proto::VarType::LOD_TENSOR;
53+
return ctx->GetType(name) == proto::VarType::LOD_TENSOR;
5454
});
5555
if (any_input_is_lod_tensor) {
5656
default_var_type = proto::VarType::LOD_TENSOR;
5757
}
5858

59-
auto out_var_name = ctx.Output("Out").front();
60-
ctx.SetType(out_var_name, default_var_type);
59+
auto out_var_name = ctx->Output("Out").front();
60+
ctx->SetType(out_var_name, default_var_type);
6161
}
6262
};
6363
} // namespace framework

paddle/fluid/imperative/tracer.cc

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ Tracer::Tracer(framework::BlockDesc* root_block) : root_block_(root_block) {
161161
}
162162

163163
std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
164-
VarBasePtrMap& outputs,
164+
VarBasePtrMap* outputs,
165165
framework::AttributeMap attrs_map,
166166
const platform::Place expected_place,
167167
const bool stop_gradient) {
@@ -195,7 +195,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
195195
}
196196
}
197197

198-
op->output_vars_ = outputs;
198+
op->output_vars_ = *outputs;
199199
for (auto it : op->output_vars_) {
200200
auto& outvars = outvars_map[it.first];
201201
const std::vector<VarBase*>& outputs = it.second;
@@ -218,7 +218,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
218218
framework::VariableNameMap invars_name_map =
219219
CreateInputVarNameMap(op, inputs);
220220
framework::VariableNameMap outvars_name_map =
221-
CreateOutputVarNameMap(op, outputs);
221+
CreateOutputVarNameMap(op, *outputs);
222222

223223
auto& info = framework::OpInfoMap::Instance().Get(op->Type());
224224
if (info.Checker() != nullptr) {
@@ -230,8 +230,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
230230
outvars_name_map, attrs_map);
231231

232232
if (info.infer_var_type_) {
233-
RuntimeInferVarTypeContext infer_var_type_ctx(&inputs, &outputs,
234-
&attrs_map);
233+
RuntimeInferVarTypeContext infer_var_type_ctx(&inputs, outputs, &attrs_map);
235234
info.infer_var_type_(&infer_var_type_ctx);
236235
}
237236

paddle/fluid/imperative/tracer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class Tracer {
4848
virtual ~Tracer() {}
4949

5050
std::set<std::string> Trace(OpBase* op, const VarBasePtrMap& inputs,
51-
VarBasePtrMap& outputs, // NOLINT
51+
VarBasePtrMap* outputs, // NOLINT
5252
framework::AttributeMap attrs_map,
5353
const platform::Place expected_place,
5454
const bool stop_gradient = false);

paddle/fluid/operators/beam_search_decode_op.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -203,12 +203,12 @@ class BeamSearchDecodeInferShape : public framework::InferShapeBase {
203203

204204
class BeamSearchDecodeInferVarType : public framework::VarTypeInference {
205205
public:
206-
void operator()(framework::InferVarTypeContext& ctx) const override {
207-
for (auto& o : ctx.Output("SentenceIds")) {
208-
ctx.SetType(o, framework::proto::VarType::LOD_TENSOR);
206+
void operator()(framework::InferVarTypeContext* ctx) const override {
207+
for (auto& o : ctx->Output("SentenceIds")) {
208+
ctx->SetType(o, framework::proto::VarType::LOD_TENSOR);
209209
}
210-
for (auto& o : ctx.Output("SentenceScores")) {
211-
ctx.SetType(o, framework::proto::VarType::LOD_TENSOR);
210+
for (auto& o : ctx->Output("SentenceScores")) {
211+
ctx->SetType(o, framework::proto::VarType::LOD_TENSOR);
212212
}
213213
}
214214
};

paddle/fluid/operators/beam_search_op.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,12 +120,12 @@ class BeamSearchOp : public framework::OperatorWithKernel {
120120

121121
class BeamSearchInferVarType : public framework::VarTypeInference {
122122
public:
123-
void operator()(framework::InferVarTypeContext &ctx) const override {
124-
for (auto &o : ctx.Output("selected_ids")) {
125-
ctx.SetType(o, framework::proto::VarType::LOD_TENSOR);
123+
void operator()(framework::InferVarTypeContext *ctx) const override {
124+
for (auto &o : ctx->Output("selected_ids")) {
125+
ctx->SetType(o, framework::proto::VarType::LOD_TENSOR);
126126
}
127-
for (auto &o : ctx.Output("selected_scores")) {
128-
ctx.SetType(o, framework::proto::VarType::LOD_TENSOR);
127+
for (auto &o : ctx->Output("selected_scores")) {
128+
ctx->SetType(o, framework::proto::VarType::LOD_TENSOR);
129129
}
130130
}
131131
};

paddle/fluid/operators/controlflow/tensor_array_read_write_op.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,13 +100,13 @@ class WriteToArrayInferShape : public framework::InferShapeBase {
100100

101101
class WriteToArrayInferVarType : public framework::VarTypeInference {
102102
public:
103-
void operator()(framework::InferVarTypeContext &ctx) const override {
104-
auto x_name = ctx.Input("X")[0];
105-
auto out_name = ctx.Output("Out")[0];
103+
void operator()(framework::InferVarTypeContext *ctx) const override {
104+
auto x_name = ctx->Input("X")[0];
105+
auto out_name = ctx->Output("Out")[0];
106106
VLOG(10) << "Set Variable " << out_name << " as LOD_TENSOR_ARRAY";
107-
ctx.SetType(out_name, framework::proto::VarType::LOD_TENSOR_ARRAY);
108-
if (ctx.HasVar(x_name)) {
109-
ctx.SetDataType(out_name, ctx.GetDataType(x_name));
107+
ctx->SetType(out_name, framework::proto::VarType::LOD_TENSOR_ARRAY);
108+
if (ctx->HasVar(x_name)) {
109+
ctx->SetDataType(out_name, ctx->GetDataType(x_name));
110110
}
111111
}
112112
};

0 commit comments

Comments
 (0)