Skip to content

Commit 16a9dfe

Browse files
committed
finish
1 parent ec69768 commit 16a9dfe

File tree

4 files changed

+19
-15
lines changed

4 files changed

+19
-15
lines changed

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
5757

5858
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op,
5959
const platform::Place &p,
60-
const size_t &i,
61-
bool create_output) const {
60+
const size_t &i) const {
6261
auto *op_handle = result->ops_.back().get();
6362
op_handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>(
6463
platform::DeviceContextPool::Instance().Get(p));
@@ -69,12 +68,11 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op,
6968
VarHandle *var = CreateOrGetLatestVarHandle(result, each_var_name, p, i);
7069
op_handle->AddInput(var);
7170
}
72-
if (create_output) {
73-
var_names = op->OutputArgumentNames();
7471

75-
for (auto &each_var_name : var_names) {
76-
CreateOpOutput(result, op_handle, each_var_name, p, i);
77-
}
72+
var_names = op->OutputArgumentNames();
73+
74+
for (auto &each_var_name : var_names) {
75+
CreateOpOutput(result, op_handle, each_var_name, p, i);
7876
}
7977
}
8078

@@ -106,10 +104,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
106104
auto &p = places_[0];
107105
auto *s = local_scopes_[0];
108106
// FIXME(wuyi): send op always copy from GPU 0
109-
result.ops_.emplace_back(new SendOpHandle(*op, s));
107+
result.ops_.emplace_back(new SendOpHandle(*op, s, p));
110108
// Create inputs for output on original place and no ssa output
111109
// is created for send op.
112-
CreateOpHandleIOs(&result, op, p, 0, false);
110+
CreateOpHandleIOs(&result, op, p, 0);
113111
continue;
114112
}
115113

paddle/fluid/framework/details/multi_devices_graph_builder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
4646

4747
private:
4848
void CreateOpHandleIOs(SSAGraph *result, OpDesc *op, const platform::Place &p,
49-
const size_t &i, bool create_output = true) const;
49+
const size_t &i) const;
5050

5151
private:
5252
std::string loss_var_name_;

paddle/fluid/framework/details/send_op_handle.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,22 @@ namespace framework {
1919
namespace details {
2020

2121
SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc,
22-
const Scope *local_scope)
22+
const Scope *local_scope,
23+
const platform::Place &place)
2324
: op_(framework::OpRegistry::CreateOp(op_desc)),
24-
local_scope_(local_scope) {}
25+
local_scope_(local_scope),
26+
place_(place) {}
2527

2628
void SendOpHandle::RunImpl() {
2729
// Wait input done
2830
for (auto *in : inputs_) {
2931
auto &p = static_cast<VarHandle *>(in)->place_;
32+
if (in->DebugString() == "dummy") { // HACK
33+
continue;
34+
}
3035
in->generated_op_->Wait(dev_ctxes_[p]);
3136
}
32-
platform::CPUPlace cpu;
33-
op_->Run(*local_scope_, cpu);
37+
op_->Run(*local_scope_, place_);
3438
}
3539

3640
std::string SendOpHandle::Name() const { return "send"; }

paddle/fluid/framework/details/send_op_handle.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@ namespace details {
3131
struct SendOpHandle : public OpHandleBase {
3232
std::unique_ptr<OperatorBase> op_;
3333
const Scope* local_scope_;
34+
const platform::Place& place_;
3435

35-
SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope);
36+
SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
37+
const platform::Place& place);
3638

3739
std::string Name() const override;
3840

0 commit comments

Comments
 (0)