Skip to content

Commit 0aa9546

Browse files
author
Yancey
authored
fix dist train error (#11281)
* fix dist train error * update by comment
1 parent 8fa457f commit 0aa9546

File tree

3 files changed

+8
-9
lines changed

3 files changed

+8
-9
lines changed

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -464,9 +464,8 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
464464

465465
void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
466466
const OpDesc &op) const {
467-
auto &p = places_[0];
468-
auto *s = local_scopes_[0];
469-
result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type()));
467+
result->ops_.emplace_back(
468+
new RPCOpHandle(op, local_scopes_[0], op.Type(), places_[0]));
470469

471470
if (op.Type() == "send_barrier") {
472471
ConnectOp(result, result->ops_.back().get(), "send_vars");

paddle/fluid/framework/details/rpc_op_handle.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@ namespace framework {
1919
namespace details {
2020

2121
RPCOpHandle::RPCOpHandle(const framework::OpDesc &op_desc,
22-
const Scope *local_scope, const platform::Place &place,
23-
const std::string &name)
22+
const Scope *local_scope, const std::string &name,
23+
const platform::Place &place)
2424
: op_(framework::OpRegistry::CreateOp(op_desc)),
2525
local_scope_(local_scope),
26-
place_(place),
27-
name_(name) {}
26+
name_(name),
27+
place_(place) {}
2828

2929
void RPCOpHandle::RunImpl() {
3030
// TODO(wuyi): need further analysis whether wait VarDummyHandle.

paddle/fluid/framework/details/rpc_op_handle.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ namespace details {
2929

3030
struct RPCOpHandle : public OpHandleBase {
3131
RPCOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
32-
const platform::Place& place, const std::string& name);
32+
const std::string& name, const platform::Place& place);
3333

3434
std::string Name() const override;
3535

@@ -43,8 +43,8 @@ struct RPCOpHandle : public OpHandleBase {
4343
private:
4444
std::unique_ptr<OperatorBase> op_;
4545
const Scope* local_scope_;
46-
const platform::Place& place_;
4746
const std::string name_;
47+
platform::Place place_;
4848
};
4949

5050
} // namespace details

0 commit comments

Comments
 (0)