Skip to content

Commit d74bfef

Browse files
authored
polish code of pass and executor (#56886)
* polish code of pass and executor * update ut
1 parent 061bb9d commit d74bfef

File tree

4 files changed

+41
-40
lines changed

4 files changed

+41
-40
lines changed

paddle/fluid/framework/new_executor/program_interpreter.cc

Lines changed: 38 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -123,20 +123,41 @@ void ProgramInterpreter::RunImpl() {
123123
#endif
124124
}
125125

126-
FetchList ProgramInterpreter::Run(
127-
const std::vector<std::string>& feed_names,
128-
const std::vector<phi::DenseTensor>& feed_tensors) {
126+
FetchList ProgramInterpreter::Run(const std::vector<std::string>& feed_names,
127+
bool need_fetch) {
129128
SetDeviceId(place_);
130129
CheckCUDAGraphBeforeRun(feed_names);
131130

132131
#ifdef PADDLE_WITH_DNNL
133132
platform::AttachPointerHashToMKLDNNKey(this, place_);
134133
#endif
135134

136-
bool is_build = is_build_;
137-
Prepare(feed_names, feed_tensors, is_build);
135+
if (!is_build_) {
136+
LOG_FIRST_N(INFO, 1) << "New Executor is Running.";
137+
paddle::framework::interpreter::BuildVariableScope(
138+
block_, execution_config_, &var_scope_);
138139

139-
if (is_build) {
140+
std::vector<paddle::framework::OpFuncNode> op_func_nodes;
141+
paddle::framework::interpreter::BuildOpFuncList(
142+
place_,
143+
block_,
144+
execution_config_.skip_gc_vars,
145+
&op_func_nodes,
146+
&var_scope_,
147+
execution_config_,
148+
HasLocalScope(),
149+
static_build_);
150+
SetFeedVarsInplaceSkip(feed_names);
151+
// convert vec func_list to graph
152+
Convert(&op_func_nodes);
153+
UpdateSyncOpNum();
154+
if (static_build_) {
155+
VLOG(4) << "RUN impl";
156+
RunImpl();
157+
}
158+
is_build_ = true;
159+
is_shared_results_build_ = true;
160+
} else {
140161
RunImpl();
141162
}
142163

@@ -145,8 +166,10 @@ FetchList ProgramInterpreter::Run(
145166
}
146167

147168
// return Fetch Tensors
148-
auto* fetch_var = local_scope_->FindVar(interpreter::kFetchVarName);
149-
if (fetch_var) {
169+
Scope* inner_scope =
170+
HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope();
171+
auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName);
172+
if (fetch_var && need_fetch) {
150173
auto fetch_list = std::move(*fetch_var->GetMutable<framework::FetchList>());
151174
#ifdef PADDLE_WITH_CUDA
152175
if (platform::IsCUDAGraphCapturing()) {
@@ -162,41 +185,20 @@ FetchList ProgramInterpreter::Run(
162185
}
163186
}
164187

165-
FetchList ProgramInterpreter::Run(const std::vector<std::string>& feed_names,
166-
bool need_fetch) {
188+
FetchList ProgramInterpreter::Run(
189+
const std::vector<std::string>& feed_names,
190+
const std::vector<phi::DenseTensor>& feed_tensors) {
167191
SetDeviceId(place_);
168192
CheckCUDAGraphBeforeRun(feed_names);
169193

170194
#ifdef PADDLE_WITH_DNNL
171195
platform::AttachPointerHashToMKLDNNKey(this, place_);
172196
#endif
173197

174-
if (!is_build_) {
175-
LOG_FIRST_N(INFO, 1) << "New Executor is Running.";
176-
paddle::framework::interpreter::BuildVariableScope(
177-
block_, execution_config_, &var_scope_);
198+
bool is_build = is_build_;
199+
Prepare(feed_names, feed_tensors, is_build);
178200

179-
std::vector<paddle::framework::OpFuncNode> op_func_nodes;
180-
paddle::framework::interpreter::BuildOpFuncList(
181-
place_,
182-
block_,
183-
execution_config_.skip_gc_vars,
184-
&op_func_nodes,
185-
&var_scope_,
186-
execution_config_,
187-
HasLocalScope(),
188-
static_build_);
189-
SetFeedVarsInplaceSkip(feed_names);
190-
// convert vec func_list to graph
191-
Convert(&op_func_nodes);
192-
UpdateSyncOpNum();
193-
if (static_build_) {
194-
VLOG(4) << "RUN impl";
195-
RunImpl();
196-
}
197-
is_build_ = true;
198-
is_shared_results_build_ = true;
199-
} else {
201+
if (is_build) {
200202
RunImpl();
201203
}
202204

@@ -208,7 +210,7 @@ FetchList ProgramInterpreter::Run(const std::vector<std::string>& feed_names,
208210
Scope* inner_scope =
209211
HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope();
210212
auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName);
211-
if (fetch_var && need_fetch) {
213+
if (fetch_var) {
212214
auto fetch_list = std::move(*fetch_var->GetMutable<framework::FetchList>());
213215
#ifdef PADDLE_WITH_CUDA
214216
if (platform::IsCUDAGraphCapturing()) {

paddle/fluid/pybind/ir.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,7 @@ void BindPassManager(pybind11::module *m) {
507507
},
508508
py::arg("opt_level") = 2)
509509
.def("add_pass",
510-
[](PassManager &self, std::string pass_name) {
510+
[](PassManager &self, const std::string &pass_name) {
511511
self.AddPass(
512512
std::move(ir::PassRegistry::Instance().Get(pass_name)));
513513
})

paddle/ir/transforms/dead_code_elimination_pass.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ namespace {
2626
// Now just a naive implementation.
2727
class DeadCodeEliminationPass : public ir::Pass {
2828
public:
29-
DeadCodeEliminationPass() : ir::Pass("DeadCodeEliminationPass", 0) {}
29+
DeadCodeEliminationPass() : ir::Pass("dead_code_elimination", 0) {}
3030

3131
void Run(ir::Operation *op) override {
3232
auto module_op = op->dyn_cast<ir::ModuleOp>();

test/ir/new_ir/test_pass_manager.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,7 @@ def test_op(self):
5656
pm.run(new_program)
5757
op_names = [op.name() for op in new_program.block().ops]
5858
# print(op_names)
59-
# TODO(zhiqiu): unify the name of pass
60-
self.assertEqual(pm.passes(), ['DeadCodeEliminationPass'])
59+
self.assertEqual(pm.passes(), ['dead_code_elimination'])
6160
self.assertFalse(pm.empty())
6261
self.assertTrue(
6362
'pd.uniform' not in op_names

0 commit comments

Comments
 (0)