Skip to content

Commit cb2d33a

Browse files
committed
resolve conflict
test=develop
1 parent 25123a3 commit cb2d33a

File tree

3 files changed

+11
-26
lines changed

3 files changed

+11
-26
lines changed

paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "paddle/fluid/framework/details/computation_op_handle.h"
1717
#include "paddle/fluid/framework/details/multi_devices_helper.h"
1818
#include "paddle/fluid/framework/details/op_graph_view.h"
19+
#include "paddle/fluid/framework/ir/graph_helper.h"
1920

2021
namespace paddle {
2122
namespace framework {
@@ -35,10 +36,10 @@ static bool IsLockAndRecordEventFreeComputationOpHandle(
3536

3637
std::unique_ptr<ir::Graph> ModifyOpLockAndRecordEventPass::ApplyImpl(
3738
std::unique_ptr<ir::Graph> ir_graph) const {
38-
auto &all_ops = ir_graph->Get<GraphOps>(kGraphOps);
39+
auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*ir_graph);
3940
OpGraphView graph_view(all_ops);
4041
for (auto &op : all_ops) {
41-
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op.get());
42+
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
4243
if (compute_op == nullptr) continue;
4344
bool is_lock_and_record_event_free =
4445
IsLockAndRecordEventFreeComputationOpHandle(compute_op, graph_view);

paddle/fluid/framework/details/op_graph_view.cc

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,16 @@ namespace paddle {
2020
namespace framework {
2121
namespace details {
2222

23-
OpGraphView::OpGraphView(
24-
const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
25-
Build(ops);
26-
}
23+
OpGraphView::OpGraphView(const std::vector<OpHandleBase *> &ops) { Build(ops); }
2724

28-
void OpGraphView::Build(const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
25+
void OpGraphView::Build(const std::vector<OpHandleBase *> &ops) {
2926
for (auto &op : ops) {
30-
preceding_ops_[op.get()];
31-
pending_ops_[op.get()];
27+
preceding_ops_[op];
28+
pending_ops_[op];
3229
for (auto &var : op->Outputs()) {
3330
for (auto &pending_op : var->PendingOps()) {
34-
preceding_ops_[pending_op].insert(op.get());
35-
pending_ops_[op.get()].insert(pending_op);
31+
preceding_ops_[pending_op].insert(op);
32+
pending_ops_[op].insert(pending_op);
3633
}
3734
}
3835
}
@@ -41,8 +38,6 @@ void OpGraphView::Build(const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
4138
"There are duplicate ops in graph.");
4239
}
4340

44-
size_t OpGraphView::OpNumber() const { return preceding_ops_.size(); }
45-
4641
std::unordered_set<OpHandleBase *> OpGraphView::AllOps() const {
4742
std::unordered_set<OpHandleBase *> ret;
4843
for (auto &pair : preceding_ops_) {
@@ -60,12 +55,6 @@ void OpGraphView::EnforceHasOp(OpHandleBase *op) const {
6055
op == nullptr ? "nullptr" : op->DebugString());
6156
}
6257

63-
const std::unordered_set<OpHandleBase *> &OpGraphView::PrecedingOps(
64-
OpHandleBase *op) const {
65-
EnforceHasOp(op);
66-
return preceding_ops_.at(op);
67-
}
68-
6958
const std::unordered_set<OpHandleBase *> &OpGraphView::PendingOps(
7059
OpHandleBase *op) const {
7160
EnforceHasOp(op);

paddle/fluid/framework/details/op_graph_view.h

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,21 +26,16 @@ namespace details {
2626

2727
class OpGraphView {
2828
public:
29-
explicit OpGraphView(const std::vector<std::unique_ptr<OpHandleBase>> &ops);
30-
31-
size_t OpNumber() const;
29+
explicit OpGraphView(const std::vector<OpHandleBase *> &ops);
3230

3331
std::unordered_set<OpHandleBase *> AllOps() const;
3432

35-
const std::unordered_set<OpHandleBase *> &PrecedingOps(
36-
OpHandleBase *op) const;
37-
3833
const std::unordered_set<OpHandleBase *> &PendingOps(OpHandleBase *op) const;
3934

4035
bool HasOp(OpHandleBase *op) const;
4136

4237
private:
43-
void Build(const std::vector<std::unique_ptr<OpHandleBase>> &ops);
38+
void Build(const std::vector<OpHandleBase *> &ops);
4439
void EnforceHasOp(OpHandleBase *op) const;
4540

4641
std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>>

0 commit comments

Comments
 (0)