
Commit dd343a4

Merge remote-tracking branch 'ups/develop' into fea/jit/vadd
2 parents: b68ecec + fcbe84c

30 files changed: +734 -206 lines

paddle/fluid/framework/details/CMakeLists.txt

Lines changed: 9 additions & 5 deletions
@@ -1,5 +1,6 @@
 cc_library(var_handle SRCS var_handle.cc DEPS place framework_proto node)
 cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context lod_tensor)
+cc_library(op_graph_view SRCS op_graph_view.cc DEPS op_handle_base)
 cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
 cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
 cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
@@ -30,7 +31,9 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_
 cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
 cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope)

-if(WITH_GPU)
+cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
+
+if (WITH_GPU)
   cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle scale_loss_grad_op_handle rpc_op_handle
       all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass)
 endif()
@@ -40,12 +43,13 @@ cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS grap
 cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
     scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle)

-if(WITH_GPU)
-  cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass sequential_execution_pass)
-else()
-  cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto sequential_execution_pass)
+set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass)
+if (WITH_GPU)
+  list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass)
 endif()

+cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
+
 cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
     simple_threadpool device_context)

paddle/fluid/framework/details/build_strategy.cc

Lines changed: 5 additions & 0 deletions
@@ -69,6 +69,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {

     // Verify that the graph is correct for multi-device executor.
     AppendPass("multi_devices_check_pass");
+
+    if (strategy_.remove_unnecessary_lock_) {
+      AppendPass("modify_op_lock_and_record_event_pass");
+    }
   }

  private:
@@ -136,3 +140,4 @@ USE_PASS(multi_devices_pass);
 USE_PASS(multi_devices_check_pass);
 USE_PASS(multi_devices_print_pass);
 USE_PASS(sequential_execution_pass);
+USE_PASS(modify_op_lock_and_record_event_pass);
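
REGISTER_PASS in the new pass source further down registers a factory for
"modify_op_lock_and_record_event_pass" in a global registry; the USE_PASS line
added here forces that translation unit's registration object to be linked in.
A minimal sketch of the pattern (PassBase, Registry, and Registrar are
illustrative stand-ins, not Paddle's actual macro expansion):

#include <functional>
#include <map>
#include <memory>
#include <string>

struct PassBase {
  virtual ~PassBase() = default;
};

// One global name -> factory map, populated during static initialization.
std::map<std::string, std::function<std::unique_ptr<PassBase>()>> &Registry() {
  static std::map<std::string, std::function<std::unique_ptr<PassBase>()>> r;
  return r;
}

// REGISTER_PASS(name, Type) expands, roughly, to a static Registrar<Type>
// object; USE_PASS(name) references a symbol from that object's translation
// unit so the linker cannot drop the registration as dead code.
template <typename T>
struct Registrar {
  explicit Registrar(const std::string &name) {
    Registry()[name] = [] { return std::unique_ptr<PassBase>(new T()); };
  }
};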

paddle/fluid/framework/details/build_strategy.h

Lines changed: 2 additions & 0 deletions
@@ -73,6 +73,8 @@ struct BuildStrategy {

   bool fuse_broadcast_op_{false};

+  bool remove_unnecessary_lock_{false};
+
   // User normally doesn't need to call this API.
   // The PassBuilder allows for more customized insert, remove of passes
   // from python side.
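
The new flag defaults to false, so existing builds behave as before until a
caller opts in. A minimal sketch of opting in from C++ (EnableLockRemoval is a
hypothetical helper; the surrounding ParallelExecutor setup is assumed, not
part of this diff):

#include "paddle/fluid/framework/details/build_strategy.h"

// Opt in before constructing the executor; ParallelExecutorPassBuilder will
// then append modify_op_lock_and_record_event_pass after the check pass.
void EnableLockRemoval(paddle::framework::details::BuildStrategy *strategy) {
  strategy->remove_unnecessary_lock_ = true;
}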

paddle/fluid/framework/details/computation_op_handle.cc

Lines changed: 8 additions & 2 deletions
@@ -29,9 +29,15 @@ ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope,
 void ComputationOpHandle::RunImpl() {
   WaitInputVarGenerated(place_);

-  this->RunAndRecordEvent([this] {
+  auto run_func = [this]() {
     op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
-  });
+  };
+
+  if (is_lock_and_record_event_free_) {
+    run_func();
+  } else {
+    this->RunAndRecordEvent(run_func);
+  }
 }

 bool ComputationOpHandle::NeedWait(VarHandleBase *in_var) {
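
The split exists because RunAndRecordEvent does more than run the lambda: it
presumably serializes the op against the device context and records a CUDA
event that consumers on other streams or devices wait on. When the new pass
proves every consumer runs on the same place (hence the same stream),
in-stream ordering already guarantees correctness, so the lock and event are
pure overhead. A standalone toy restatement of the dispatch (stub names, not
Paddle APIs):

#include <functional>
#include <iostream>

// Stand-in for OpHandleBase::RunAndRecordEvent: lock, run, record an event.
void RunAndRecordEventStub(const std::function<void()> &f) {
  std::cout << "lock device context\n";
  f();
  std::cout << "record event for cross-stream consumers\n";
}

void RunMaybeLockFree(bool is_lock_and_record_event_free) {
  auto run_func = [] { std::cout << "op_->Run(...)\n"; };
  if (is_lock_and_record_event_free) {
    run_func();  // all consumers share this op's stream; no event needed
  } else {
    RunAndRecordEventStub(run_func);
  }
}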

paddle/fluid/framework/details/computation_op_handle.h

Lines changed: 3 additions & 0 deletions
@@ -36,6 +36,8 @@ struct ComputationOpHandle : public OpHandleBase {

   const platform::Place &GetPlace() const { return place_; }

+  void SetLockAndRecordEventFree(bool b) { is_lock_and_record_event_free_ = b; }
+
  protected:
   void RunImpl() override;

@@ -45,6 +47,7 @@ struct ComputationOpHandle : public OpHandleBase {
   std::unique_ptr<OperatorBase> op_;
   Scope *scope_;
   platform::Place place_;
+  bool is_lock_and_record_event_free_{false};
 };
 }  // namespace details
 }  // namespace framework

paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.cc

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.h"
+#include "paddle/fluid/framework/details/computation_op_handle.h"
+#include "paddle/fluid/framework/details/multi_devices_helper.h"
+#include "paddle/fluid/framework/details/op_graph_view.h"
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+static bool IsLockAndRecordEventFreeComputationOpHandle(
+    ComputationOpHandle *op, const OpGraphView &graph_view) {
+  if (!platform::is_gpu_place(op->GetPlace())) return false;
+  for (auto &pending_op : graph_view.PendingOps(op)) {
+    auto *tmp = dynamic_cast<ComputationOpHandle *>(pending_op);
+    if (tmp == nullptr || !(tmp->GetPlace() == op->GetPlace())) {
+      return false;
+    }
+  }
+  return true;
+}
+
+std::unique_ptr<ir::Graph> ModifyOpLockAndRecordEventPass::ApplyImpl(
+    std::unique_ptr<ir::Graph> ir_graph) const {
+  auto &all_ops = ir_graph->Get<GraphOps>(kGraphOps);
+  OpGraphView graph_view(all_ops);
+  for (auto &op : all_ops) {
+    auto *compute_op = dynamic_cast<ComputationOpHandle *>(op.get());
+    if (compute_op == nullptr) continue;
+    bool is_lock_and_record_event_free =
+        IsLockAndRecordEventFreeComputationOpHandle(compute_op, graph_view);
+    compute_op->SetLockAndRecordEventFree(is_lock_and_record_event_free);
+    if (is_lock_and_record_event_free) {
+      VLOG(10) << "Set is_lock_and_record_event_free be true in op "
+               << compute_op->DebugString();
+    }
+  }
+  return ir_graph;
+}
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(modify_op_lock_and_record_event_pass,
+              paddle::framework::details::ModifyOpLockAndRecordEventPass);
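
In effect, an op qualifies only if it runs on a GPU and every consumer of its
outputs is itself a ComputationOpHandle on the same place; any communication
or cross-device consumer (all_reduce, broadcast, fetch, ...) keeps the
event-recording path. A toy restatement of the predicate with the graph
machinery stubbed out (Consumer is an illustrative stand-in, not a Paddle
type):

#include <vector>

// What the predicate sees for each pending op, reduced to two facts.
struct Consumer {
  bool is_computation_op;  // the dynamic_cast succeeded
  int place_id;            // device the consumer runs on
};

bool IsLockAndRecordEventFree(bool op_on_gpu, int op_place_id,
                              const std::vector<Consumer> &consumers) {
  if (!op_on_gpu) return false;  // CPU ops never take the fast path
  for (const auto &c : consumers) {
    if (!c.is_computation_op || c.place_id != op_place_id) return false;
  }
  return true;  // every consumer shares this op's stream; event is redundant
}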

paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.h

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/pass.h"
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+class ModifyOpLockAndRecordEventPass : public ir::Pass {
+ protected:
+  std::unique_ptr<ir::Graph> ApplyImpl(
+      std::unique_ptr<ir::Graph> graph) const override;
+};
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle

paddle/fluid/framework/details/op_graph_view.cc

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/details/op_graph_view.h"
+#include <queue>
+#include <utility>
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+OpGraphView::OpGraphView(
+    const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
+  Build(ops);
+}
+
+void OpGraphView::Build(const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
+  for (auto &op : ops) {
+    preceding_ops_[op.get()];
+    pending_ops_[op.get()];
+    for (auto &var : op->Outputs()) {
+      for (auto &pending_op : var->PendingOps()) {
+        preceding_ops_[pending_op].insert(op.get());
+        pending_ops_[op.get()].insert(pending_op);
+      }
+    }
+  }
+  PADDLE_ENFORCE(
+      preceding_ops_.size() == ops.size() && pending_ops_.size() == ops.size(),
+      "There are duplicate ops in graph.");
+}
+
+size_t OpGraphView::OpNumber() const { return preceding_ops_.size(); }
+
+std::unordered_set<OpHandleBase *> OpGraphView::AllOps() const {
+  std::unordered_set<OpHandleBase *> ret;
+  for (auto &pair : preceding_ops_) {
+    ret.insert(pair.first);
+  }
+  return ret;
+}
+
+bool OpGraphView::HasOp(OpHandleBase *op) const {
+  return preceding_ops_.count(op) != 0;
+}
+
+void OpGraphView::EnforceHasOp(OpHandleBase *op) const {
+  PADDLE_ENFORCE(HasOp(op), "Cannot find op %s in OpGraphView",
+                 op == nullptr ? "nullptr" : op->DebugString());
+}
+
+const std::unordered_set<OpHandleBase *> &OpGraphView::PrecedingOps(
+    OpHandleBase *op) const {
+  EnforceHasOp(op);
+  return preceding_ops_.at(op);
+}
+
+const std::unordered_set<OpHandleBase *> &OpGraphView::PendingOps(
+    OpHandleBase *op) const {
+  EnforceHasOp(op);
+  return pending_ops_.at(op);
+}
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle

paddle/fluid/framework/details/op_graph_view.h

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+#include "paddle/fluid/framework/details/op_handle_base.h"
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+class OpGraphView {
+ public:
+  explicit OpGraphView(const std::vector<std::unique_ptr<OpHandleBase>> &ops);
+
+  size_t OpNumber() const;
+
+  std::unordered_set<OpHandleBase *> AllOps() const;
+
+  const std::unordered_set<OpHandleBase *> &PrecedingOps(
+      OpHandleBase *op) const;
+
+  const std::unordered_set<OpHandleBase *> &PendingOps(OpHandleBase *op) const;
+
+  bool HasOp(OpHandleBase *op) const;
+
+ private:
+  void Build(const std::vector<std::unique_ptr<OpHandleBase>> &ops);
+  void EnforceHasOp(OpHandleBase *op) const;
+
+  std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>>
+      preceding_ops_;
+  std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>>
+      pending_ops_;
+};
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
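
A minimal usage sketch of the new view (InspectGraph is a hypothetical caller,
and the ops vector is assumed to come from the multi-devices graph pass; not
part of this commit):

#include <iostream>
#include <memory>
#include <vector>
#include "paddle/fluid/framework/details/op_graph_view.h"

using paddle::framework::details::OpGraphView;
using paddle::framework::details::OpHandleBase;

void InspectGraph(const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
  OpGraphView view(ops);  // one walk over op outputs builds both maps
  for (OpHandleBase *op : view.AllOps()) {
    // PendingOps(op): consumers of op's outputs; PrecedingOps(op): producers.
    std::cout << op->DebugString()
              << " fan_in=" << view.PrecedingOps(op).size()
              << " fan_out=" << view.PendingOps(op).size() << "\n";
  }
}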

paddle/fluid/framework/details/reference_count_op_handle.h

Lines changed: 2 additions & 2 deletions
@@ -51,7 +51,7 @@ class ReferenceCountOpHandle : public OpHandleBase {
     dev_ctx_ = static_cast<platform::CUDADeviceContext *>(
         platform::DeviceContextPool::Instance().Get(place));
     if (IsStreamGarabageCollector()) {
-      PADDLE_ENFORCE(cudaSetDevice(place.device));
+      platform::SetDeviceId(place.device);
       PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
     }

@@ -61,7 +61,7 @@ class ReferenceCountOpHandle : public OpHandleBase {
   ~ReferenceCountOpHandle() {
     if (IsStreamGarabageCollector()) {
       auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx_->GetPlace());
-      PADDLE_ENFORCE(cudaSetDevice(gpu_place.device));
+      platform::SetDeviceId(gpu_place.device);
       PADDLE_ENFORCE(cudaEventDestroy(event_));
     }
   }
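
Swapping the raw cudaSetDevice calls for platform::SetDeviceId keeps direct
CUDA runtime calls out of handle code and centralizes the error check. The
wrapper's definition is not part of this commit; a sketch of what it
presumably does (an assumption, not Paddle's actual implementation):

#include <stdexcept>

#include <cuda_runtime.h>

// Assumed shape of platform::SetDeviceId: forward to cudaSetDevice and fail
// loudly, as the PADDLE_ENFORCE(cudaSetDevice(...)) it replaces did.
void SetDeviceIdSketch(int id) {
  cudaError_t err = cudaSetDevice(id);
  if (err != cudaSuccess) {
    throw std::runtime_error(cudaGetErrorString(err));
  }
}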
