Skip to content

Commit ce08dc8

Browse files
committed
have stream removed error
1 parent 0bf799a commit ce08dc8

File tree

6 files changed

+24
-31
lines changed

6 files changed

+24
-31
lines changed

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -57,19 +57,24 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
5757

5858
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op,
5959
const platform::Place &p,
60-
const size_t &i) const {
60+
const size_t &i,
61+
bool create_output) const {
6162
auto *op_handle = result->ops_.back().get();
63+
op_handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>(
64+
platform::DeviceContextPool::Instance().Get(p));
6265

6366
auto var_names = op->InputArgumentNames();
6467

6568
for (auto &each_var_name : var_names) {
6669
VarHandle *var = CreateOrGetLatestVarHandle(result, each_var_name, p, i);
6770
op_handle->AddInput(var);
6871
}
69-
var_names = op->OutputArgumentNames();
72+
if (create_output) {
73+
var_names = op->OutputArgumentNames();
7074

71-
for (auto &each_var_name : var_names) {
72-
CreateOpOutput(result, op_handle, each_var_name, p, i);
75+
for (auto &each_var_name : var_names) {
76+
CreateOpOutput(result, op_handle, each_var_name, p, i);
77+
}
7378
}
7479
}
7580

@@ -100,9 +105,11 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
100105
if (!is_forwarding && op->Type() == "send") {
101106
auto &p = places_[0];
102107
auto *s = local_scopes_[0];
103-
size_t i = 0;
104-
result.ops_.emplace_back(new SendOpHandle(*op, s, p));
105-
CreateOpHandleIOs(&result, op, p, i);
108+
// FIXME(wuyi): send op always copies from GPU 0
109+
result.ops_.emplace_back(new SendOpHandle(*op, s));
110+
// Create the send op's inputs on the original place; no SSA output
111+
// is created for the send op.
112+
CreateOpHandleIOs(&result, op, p, 0, false);
106113
continue;
107114
}
108115

@@ -112,23 +119,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
112119

113120
result.ops_.emplace_back(new ComputationOpHandle(*op, s, p));
114121
auto *op_handle = result.ops_.back().get();
115-
op_handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>(
116-
platform::DeviceContextPool::Instance().Get(p));
117-
118122
CreateOpHandleIOs(&result, op, p, i);
119-
// auto var_names = op->InputArgumentNames();
120123

121-
// for (auto &each_var_name : var_names) {
122-
// VarHandle *var =
123-
// CreateOrGetLatestVarHandle(&result, each_var_name, p, i);
124-
// op_handle->AddInput(var);
125-
// }
126124
auto var_names = op->OutputArgumentNames();
127125

128-
// for (auto &each_var_name : var_names) {
129-
// CreateOpOutput(&result, op_handle, each_var_name, p, i);
130-
// }
131-
132126
if (is_forwarding) {
133127
if (var_names.size() == 1 && var_names[0] == loss_var_name_) {
134128
// Insert ScaleCost OpHandle

paddle/fluid/framework/details/multi_devices_graph_builder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
4646

4747
private:
4848
void CreateOpHandleIOs(SSAGraph *result, OpDesc *op, const platform::Place &p,
49-
const size_t &i) const;
49+
const size_t &i, bool create_output = true) const;
5050

5151
private:
5252
std::string loss_var_name_;

paddle/fluid/framework/details/send_op_handle.cc

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,18 @@ namespace framework {
1919
namespace details {
2020

2121
SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc,
22-
const Scope *local_scope,
23-
const platform::Place &place)
22+
const Scope *local_scope)
2423
: op_(framework::OpRegistry::CreateOp(op_desc)),
25-
local_scope_(local_scope),
26-
place_(place) {}
24+
local_scope_(local_scope) {}
2725

2826
void SendOpHandle::RunImpl() {
2927
// Wait input done
3028
for (auto *in : inputs_) {
3129
auto &p = static_cast<VarHandle *>(in)->place_;
3230
in->generated_op_->Wait(dev_ctxes_[p]);
3331
}
34-
35-
op_->Run(*local_scope_, place_);
32+
platform::CPUPlace cpu;
33+
op_->Run(*local_scope_, cpu);
3634
}
3735

3836
std::string SendOpHandle::Name() const { return "send"; }

paddle/fluid/framework/details/send_op_handle.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,8 @@ namespace details {
3131
struct SendOpHandle : public OpHandleBase {
3232
std::unique_ptr<OperatorBase> op_;
3333
const Scope* local_scope_;
34-
const platform::Place& place_;
3534

36-
SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
37-
const platform::Place& place);
35+
SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope);
3836

3937
std::string Name() const override;
4038

python/paddle/fluid/distribute_transpiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,7 @@ def transpile(self,
255255
def get_trainer_program(self):
256256
# remove optimize ops and add a send op to main_program
257257
self.program.global_block().delete_ops(self.optimize_ops)
258+
self.program.sync_with_cpp()
258259
# FIXME(typhoonzero): serializing once fixes an error that occurs when cloning.
259260
self.program.__str__()
260261
return self.program

python/paddle/fluid/parallel_executor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,9 @@ def __init__(self,
101101

102102
self.persistable_vars = [
103103
v.name
104-
for v in filter(lambda var: var.persistable, main.list_vars())
104+
for v in filter(lambda var: \
105+
var.persistable and var.type != core.VarDesc.VarType.RAW,
106+
main.list_vars())
105107
]
106108

107109
self.executor = core.ParallelExecutor(

0 commit comments

Comments
 (0)