Skip to content

Commit ab91046

Browse files
authored
Merge pull request #9934 from reyoung/feature/PolishCreateOpHandleIOs
CreateOpHandleIOs
2 parents 89727c9 + 63f9215 commit ab91046

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,21 +55,21 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
5555
}
5656
}
5757

58-
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op,
58+
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
59+
const OpDesc &op,
5960
const platform::Place &p,
6061
const size_t &i) const {
6162
auto *op_handle = result->ops_.back().get();
62-
op_handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>(
63-
platform::DeviceContextPool::Instance().Get(p));
63+
op_handle->dev_ctxes_[p] = platform::DeviceContextPool::Instance().Get(p);
6464

65-
auto var_names = op->InputArgumentNames();
65+
auto var_names = op.InputArgumentNames();
6666

6767
for (auto &each_var_name : var_names) {
6868
VarHandle *var = CreateOrGetLatestVarHandle(result, each_var_name, p, i);
6969
op_handle->AddInput(var);
7070
}
7171

72-
var_names = op->OutputArgumentNames();
72+
var_names = op.OutputArgumentNames();
7373

7474
for (auto &each_var_name : var_names) {
7575
CreateOpOutput(result, op_handle, each_var_name, p, i);
@@ -107,7 +107,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
107107
result.ops_.emplace_back(new SendOpHandle(*op, s, p));
108108
// Create inputs for output on original place and no ssa output
109109
// is created for send op.
110-
CreateOpHandleIOs(&result, op, p, 0);
110+
CreateOpHandleIOs(&result, *op, p, 0);
111111
continue;
112112
}
113113

@@ -117,7 +117,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
117117

118118
result.ops_.emplace_back(new ComputationOpHandle(*op, s, p));
119119
auto *op_handle = result.ops_.back().get();
120-
CreateOpHandleIOs(&result, op, p, i);
120+
CreateOpHandleIOs(&result, *op, p, i);
121121

122122
auto var_names = op->OutputArgumentNames();
123123

paddle/fluid/framework/details/multi_devices_graph_builder.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
4545
std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
4646

4747
private:
48-
void CreateOpHandleIOs(SSAGraph *result, OpDesc *op, const platform::Place &p,
49-
const size_t &i) const;
48+
void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op,
49+
const platform::Place &p, const size_t &i) const;
5050

5151
private:
5252
std::string loss_var_name_;

0 commit comments

Comments
 (0)