Skip to content

Commit 37e5144

Browse files
committed
op compose node and update nodes.
1 parent 9605fcd commit 37e5144

29 files changed

+262
-153
lines changed

paddle/fluid/framework/details/all_reduce_op_handle.cc

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -23,20 +23,25 @@ namespace framework {
2323
namespace details {
2424

2525
#ifdef PADDLE_WITH_CUDA
26-
AllReduceOpHandle::AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
26+
AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
27+
const std::vector<Scope *> &local_scopes,
2728
const std::vector<platform::Place> &places,
2829
const platform::NCCLContextMap *ctxs)
29-
: local_scopes_(local_scopes), places_(places), nccl_ctxs_(ctxs) {
30+
: OpHandleBase(node),
31+
local_scopes_(local_scopes),
32+
places_(places),
33+
nccl_ctxs_(ctxs) {
3034
if (nccl_ctxs_) {
3135
for (auto &p : places_) {
3236
this->dev_ctxes_[p] = nccl_ctxs_->DevCtx(p);
3337
}
3438
}
3539
}
3640
#else
37-
AllReduceOpHandle::AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
41+
AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
42+
const std::vector<Scope *> &local_scopes,
3843
const std::vector<platform::Place> &places)
39-
: local_scopes_(local_scopes), places_(places) {}
44+
: OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
4045
#endif
4146

4247
void AllReduceOpHandle::RunImpl() {

paddle/fluid/framework/details/all_reduce_op_handle.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@ namespace details {
3030

3131
struct AllReduceOpHandle : public OpHandleBase {
3232
#ifdef PADDLE_WITH_CUDA
33-
AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
33+
AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
3434
const std::vector<platform::Place> &places,
3535
const platform::NCCLContextMap *ctxs);
3636
#else
37-
AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
37+
AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
3838
const std::vector<platform::Place> &places);
3939
#endif
4040
std::string Name() const override;

paddle/fluid/framework/details/broadcast_op_handle.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,20 +35,23 @@ namespace details {
3535
struct BroadcastOpHandle : public OpHandleBase {
3636
public:
3737
#ifdef PADDLE_WITH_CUDA
38-
BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
38+
BroadcastOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
3939
const std::vector<platform::Place> &places,
4040
const platform::NCCLContextMap *nccl_ctxs)
41-
: local_scopes_(local_scopes), places_(places), nccl_ctxs_(nccl_ctxs) {
41+
: OpHandleBase(node),
42+
local_scopes_(local_scopes),
43+
places_(places),
44+
nccl_ctxs_(nccl_ctxs) {
4245
if (nccl_ctxs_) {
4346
for (auto &p_ctx : nccl_ctxs_->contexts_) {
4447
dev_ctxes_[platform::CUDAPlace(p_ctx.first)] = p_ctx.second.ctx_.get();
4548
}
4649
}
4750
}
4851
#else
49-
BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
52+
BroadcastOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
5053
const std::vector<platform::Place> &places)
51-
: local_scopes_(local_scopes), places_(places) {}
54+
: OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
5255
#endif
5356

5457
std::string Name() const override;

paddle/fluid/framework/details/broadcast_op_handle_test.cc

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -96,48 +96,56 @@ struct TestBroadcastOpHandle {
9696
}
9797
param_scopes_[input_scope_idx]->Var("input");
9898

99+
std::unique_ptr<ir::Node> n(new ir::Node(ir::Node::Type::kOperation));
99100
if (use_gpu_) {
100101
#ifdef PADDLE_WITH_CUDA
101-
op_handle_.reset(
102-
new BroadcastOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get()));
102+
op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_,
103+
nccl_ctxs_.get()));
103104
#else
104105
PADDLE_THROW("CUDA is not support.");
105106
#endif
106107
} else {
107108
#ifdef PADDLE_WITH_CUDA
108-
op_handle_.reset(
109-
new BroadcastOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get()));
109+
op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_,
110+
nccl_ctxs_.get()));
110111
#else
111-
op_handle_.reset(new BroadcastOpHandle(local_scopes_, gpu_list_));
112+
op_handle_.reset(
113+
new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_));
112114
#endif
113115
}
114116

115-
auto* in_var_handle =
116-
new VarHandle(1, input_scope_idx, "input", gpu_list_[input_scope_idx]);
117+
std::unique_ptr<ir::Node> v(new ir::Node(ir::Node::Type::kVariable));
118+
auto* in_var_handle = new VarHandle(v.get(), 1, input_scope_idx, "input",
119+
gpu_list_[input_scope_idx]);
117120
vars_.emplace_back(in_var_handle);
118121
op_handle_->AddInput(in_var_handle);
119122

120123
// add dummy var
121-
vars_.emplace_back(new DummyVarHandle());
124+
125+
std::unique_ptr<ir::Node> v2(new ir::Node(ir::Node::Type::kVariable));
126+
vars_.emplace_back(new DummyVarHandle(v2.get()));
122127
DummyVarHandle* dummy_var_handle =
123128
static_cast<DummyVarHandle*>(vars_.back().get());
124-
dummy_var_handle->generated_op_ = nullptr;
129+
dummy_var_handle->ClearGeneratedOp();
125130
op_handle_->AddInput(dummy_var_handle);
126131

127132
for (size_t j = 0; j < gpu_list_.size(); ++j) {
128133
if (!use_gpu_) {
129134
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
130135
}
131-
VarHandle* out_var_handle = new VarHandle(2, j, "out", gpu_list_[j]);
136+
std::unique_ptr<ir::Node> v3(new ir::Node(ir::Node::Type::kVariable));
137+
VarHandle* out_var_handle =
138+
new VarHandle(v3.get(), 2, j, "out", gpu_list_[j]);
132139
vars_.emplace_back(out_var_handle);
133140
op_handle_->AddOutput(out_var_handle);
134141
}
135142

136143
// add dummy var
137-
vars_.emplace_back(new DummyVarHandle());
144+
std::unique_ptr<ir::Node> v4(new ir::Node(ir::Node::Type::kVariable));
145+
vars_.emplace_back(new DummyVarHandle(v4.get()));
138146
DummyVarHandle* out_dummy_var_handle =
139147
static_cast<DummyVarHandle*>(vars_.back().get());
140-
out_dummy_var_handle->generated_op_ = nullptr;
148+
out_dummy_var_handle->ClearGeneratedOp();
141149
op_handle_->AddOutput(out_dummy_var_handle);
142150
}
143151

paddle/fluid/framework/details/computation_op_handle.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,10 @@
1919
namespace paddle {
2020
namespace framework {
2121
namespace details {
22-
ComputationOpHandle::ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
23-
platform::Place place)
24-
: op_(framework::OpRegistry::CreateOp(op_desc)),
22+
ComputationOpHandle::ComputationOpHandle(ir::Node *node, const OpDesc &op_desc,
23+
Scope *scope, platform::Place place)
24+
: OpHandleBase(node),
25+
op_(framework::OpRegistry::CreateOp(op_desc)),
2526
scope_(scope),
2627
place_(place) {}
2728

@@ -35,8 +36,8 @@ void ComputationOpHandle::RunImpl() {
3536

3637
bool ComputationOpHandle::NeedWait(VarHandleBase *in_var) {
3738
bool need_wait =
38-
in_var && in_var->generated_op_ &&
39-
in_var->generated_op_->DeviceContext(place_) != dev_ctxes_[place_];
39+
in_var && in_var->GeneratedOp() &&
40+
in_var->GeneratedOp()->DeviceContext(place_) != dev_ctxes_[place_];
4041
return need_wait;
4142
}
4243

paddle/fluid/framework/details/computation_op_handle.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ namespace framework {
2828
namespace details {
2929
struct ComputationOpHandle : public OpHandleBase {
3030
public:
31-
ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
31+
ComputationOpHandle(ir::Node *node, const OpDesc &op_desc, Scope *scope,
3232
platform::Place place);
3333

3434
std::string Name() const override;

paddle/fluid/framework/details/data_balance_op_handle.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@ namespace details {
2222

2323
#ifdef PADDLE_WITH_CUDA
2424
DataBalanceOpHandle::DataBalanceOpHandle(
25-
const std::vector<Scope *> &local_scopes,
25+
ir::Node *node, const std::vector<Scope *> &local_scopes,
2626
const std::vector<platform::Place> &places,
2727
const platform::NCCLContextMap *ctxs)
28-
: local_scopes_(local_scopes), places_(places) {
28+
: OpHandleBase(node), local_scopes_(local_scopes), places_(places) {
2929
if (ctxs) {
3030
for (auto &p : places_) {
3131
this->dev_ctxes_[p] = ctxs->DevCtx(p);
@@ -34,9 +34,9 @@ DataBalanceOpHandle::DataBalanceOpHandle(
3434
}
3535
#else
3636
DataBalanceOpHandle::DataBalanceOpHandle(
37-
const std::vector<Scope *> &local_scopes,
37+
ir::Node *node, const std::vector<Scope *> &local_scopes,
3838
const std::vector<platform::Place> &places)
39-
: local_scopes_(local_scopes), places_(places) {}
39+
: OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
4040
#endif
4141

4242
std::string DataBalanceOpHandle::Name() const { return "data balance"; }

paddle/fluid/framework/details/data_balance_op_handle.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@ namespace details {
3030
struct DataBalanceOpHandle : public OpHandleBase {
3131
public:
3232
#ifdef PADDLE_WITH_CUDA
33-
DataBalanceOpHandle(const std::vector<Scope *> &local_scopes,
33+
DataBalanceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
3434
const std::vector<platform::Place> &places,
3535
const platform::NCCLContextMap *ctxs);
3636
#else
37-
DataBalanceOpHandle(const std::vector<Scope *> &local_scopes,
37+
DataBalanceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
3838
const std::vector<platform::Place> &places);
3939
#endif
4040

paddle/fluid/framework/details/fetch_op_handle.cc

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,16 @@ namespace paddle {
2121
namespace framework {
2222
namespace details {
2323

24-
FetchOpHandle::FetchOpHandle(FeedFetchList *data, size_t offset,
24+
FetchOpHandle::FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset,
2525
std::vector<Scope *> *local_scopes)
26-
: data_(data), offset_(offset), local_scopes_(local_scopes) {}
26+
: OpHandleBase(node),
27+
data_(data),
28+
offset_(offset),
29+
local_scopes_(local_scopes) {}
2730

2831
FetchOpHandle::~FetchOpHandle() {
2932
for (auto *input_var : inputs_) {
30-
input_var->pending_ops_.erase(this);
33+
input_var->RemoveOutput(this, this->Node());
3134
}
3235
}
3336

@@ -77,8 +80,8 @@ void FetchOpHandle::RunImpl() {
7780
void FetchOpHandle::WaitInputVarGenerated(const platform::Place &place) {
7881
auto cpu_ctx = platform::DeviceContextPool::Instance().Get(place);
7982
for (auto *input : inputs_) {
80-
if (input->generated_op_) {
81-
input->generated_op_->RecordWaitEventOnCtx(cpu_ctx);
83+
if (input->GeneratedOp()) {
84+
input->GeneratedOp()->RecordWaitEventOnCtx(cpu_ctx);
8285
}
8386
}
8487
}

paddle/fluid/framework/details/fetch_op_handle.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ namespace details {
2828

2929
struct FetchOpHandle : public OpHandleBase {
3030
public:
31-
FetchOpHandle(FeedFetchList *data, size_t offset,
31+
FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset,
3232
std::vector<Scope *> *local_scopes);
3333

3434
~FetchOpHandle();

0 commit comments

Comments
 (0)