Commit b9d7bd4

Merge branch 'develop' into remove/kwargs
2 parents 6d2ce74 + 6537b17 commit b9d7bd4

16 files changed: +605 additions, -432 deletions


paddle/fluid/API.spec

Lines changed: 6 additions & 6 deletions

(These six activation layers move from auto-generated *args/**kwargs wrappers to hand-written Python APIs, so their ArgSpecs gain explicit arguments and defaults; the old kwargs entries are dropped in the second hunk.)

@@ -160,6 +160,12 @@ paddle.fluid.layers.relu ArgSpec(args=['x', 'name'], varargs=None, keywords=None
 paddle.fluid.layers.log ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
 paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,))
+paddle.fluid.layers.elu ArgSpec(args=['x', 'alpha', 'name'], varargs=None, keywords=None, defaults=(1.0, None))
+paddle.fluid.layers.relu6 ArgSpec(args=['x', 'threshold', 'name'], varargs=None, keywords=None, defaults=(6.0, None))
+paddle.fluid.layers.pow ArgSpec(args=['x', 'factor', 'name'], varargs=None, keywords=None, defaults=(1.0, None))
+paddle.fluid.layers.stanh ArgSpec(args=['x', 'scale_a', 'scale_b', 'name'], varargs=None, keywords=None, defaults=(0.6666666666666666, 1.7159, None))
+paddle.fluid.layers.hard_sigmoid ArgSpec(args=['x', 'slope', 'offset', 'name'], varargs=None, keywords=None, defaults=(0.2, 0.5, None))
+paddle.fluid.layers.swish ArgSpec(args=['x', 'beta', 'name'], varargs=None, keywords=None, defaults=(1.0, None))
 paddle.fluid.layers.prelu ArgSpec(args=['x', 'mode', 'param_attr', 'name'], varargs=None, keywords=None, defaults=(None, None))
 paddle.fluid.layers.brelu ArgSpec(args=['x', 't_min', 't_max', 'name'], varargs=None, keywords=None, defaults=(0.0, 24.0, None))
 paddle.fluid.layers.leaky_relu ArgSpec(args=['x', 'alpha', 'name'], varargs=None, keywords=None, defaults=(0.02, None))
@@ -260,12 +266,6 @@ paddle.fluid.layers.slice ArgSpec(args=[], varargs='args', keywords='kwargs', de
 paddle.fluid.layers.shape ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
 paddle.fluid.layers.maxout ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
 paddle.fluid.layers.softshrink ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
-paddle.fluid.layers.elu ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
-paddle.fluid.layers.relu6 ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
-paddle.fluid.layers.pow ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
-paddle.fluid.layers.stanh ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
-paddle.fluid.layers.hard_sigmoid ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
-paddle.fluid.layers.swish ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
 paddle.fluid.layers.sigmoid ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.logsigmoid ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.exp ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))

paddle/fluid/framework/details/cow_ptr.h

Lines changed: 61 additions & 23 deletions

@@ -20,41 +20,79 @@ namespace paddle {
 namespace framework {
 namespace details {
 
-template <class T>
-class COWPtr {
+// Change it to thread safe flags if needed.
+class ThreadUnsafeOwnershipFlags {
  public:
-  typedef std::shared_ptr<T> RefPtr;
+  explicit ThreadUnsafeOwnershipFlags(bool flag) : flag_(flag) {}
 
- private:
-  RefPtr m_sp;
+  ThreadUnsafeOwnershipFlags(const ThreadUnsafeOwnershipFlags& other) = delete;
+  ThreadUnsafeOwnershipFlags& operator=(
+      const ThreadUnsafeOwnershipFlags& other) = delete;
+  ThreadUnsafeOwnershipFlags(ThreadUnsafeOwnershipFlags&& other) = default;
 
-  void detach() {
-    T* tmp = m_sp.get();
-    if (!(tmp == nullptr || m_sp.unique())) {
-      m_sp = RefPtr(new T(*tmp));
+  void SetOwnership(bool flag) { flag_ = flag; }
+
+  // Invoke the callback if it is not owned.
+  template <typename Callback>
+  void AcquireOwnershipOnce(Callback acquire) {
+    if (!flag_) {
+      acquire();
+      flag_ = true;
     }
   }
 
- public:
-  COWPtr() : m_sp(nullptr) {}
-  explicit COWPtr(T* t) : m_sp(t) {}
-  explicit COWPtr(const RefPtr& refptr) : m_sp(refptr) {}
+ private:
+  bool flag_;
+};
 
-  const T& Data() const { return operator*(); }
+// Copy-On-Write pointer.
+// It will hold a T* pointer, and only copy once when `MutableData` is invoked.
+//
+// The template parameter OwnershipFlags should have:
+//  * a constructor takes a bool. True if own.
+//  * SetOwnership(bool flag).
+//  * AcquireOwnershipOnce(Callback). It will invoke the callback if it is not
+//    owned.
+//
+// https://en.wikipedia.org/wiki/Copy-on-write
+template <typename T, typename OwnershipFlags = ThreadUnsafeOwnershipFlags>
+class COWPtr {
+ public:
+  // Ctor from raw pointer.
+  explicit COWPtr(T* ptr) : payload_(ptr), ownership_{true} {}
 
-  T* MutableData() { return operator->(); }
+  // Move methods. Steal ownership from origin
+  COWPtr(COWPtr&& other)
+      : payload_(other.payload_), ownership_{std::move(other.ownership_)} {}
+  COWPtr& operator=(COWPtr&& origin) = default;
 
-  const T& operator*() const { return *m_sp; }
-  T& operator*() {
-    detach();
-    return *m_sp;
+  // Copy methods. Not own payload
+  COWPtr(const COWPtr& other) : payload_(other.payload_), ownership_{false} {}
+  COWPtr& operator=(const COWPtr& other) {
+    payload_ = other.payload_;
+    ownership_.SetOwnership(false);
+    return *this;
   }
-  const T* operator->() const { return m_sp.operator->(); }
-  T* operator->() {
-    detach();
-    return m_sp.operator->();
+
+  // Access read only data.
+  const T& Data() const { return *payload_; }
+
+  // Access mutable data. If the data is not owned, the data will be copied
+  // before.
+  T* MutableData() {
+    ownership_.AcquireOwnershipOnce(
+        [this] { payload_.reset(new T(*payload_)); });
+    return payload_.get();
   }
+
+ private:
+  // Actual data pointer.
+  std::shared_ptr<T> payload_;
+
+  // Ownership flag.
+  OwnershipFlags ownership_;
 };
+
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
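A minimal usage sketch of the reworked COWPtr (ours, not part of the commit; it assumes only the header above and mirrors the surviving COWPtr.all test): the copy constructor shares the payload without owning it, and the first mutation through a non-owning handle triggers the copy.

#include <cassert>

#include "paddle/fluid/framework/details/cow_ptr.h"

int main() {
  using paddle::framework::details::COWPtr;

  COWPtr<int> ptr(new int{0});  // built from a raw pointer: owns its payload
  COWPtr<int> ptr2 = ptr;       // copy: shares the payload but does not own it

  // Nothing has been copied yet; both handles see the same object.
  assert(&ptr.Data() == &ptr2.Data());

  // The first mutation through the non-owning copy triggers copy-on-write.
  *ptr2.MutableData() = 10;
  assert(ptr2.Data() == 10);
  assert(ptr.Data() == 0);  // the original payload is untouched
  return 0;
}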

paddle/fluid/framework/details/cow_ptr_test.cc

Lines changed: 0 additions & 8 deletions

@@ -30,14 +30,6 @@ TEST(COWPtr, all) {
   ASSERT_EQ(ptr2.Data(), 10);
 }
 
-TEST(COWPtr, change_old) {
-  COWPtr<int> ptr(new int{0});
-  COWPtr<int> ptr2 = ptr;
-  *ptr.MutableData() = 10;
-  ASSERT_EQ(ptr2.Data(), 0);
-  ASSERT_EQ(ptr.Data(), 10);
-}
-
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
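The deleted change_old test mutated the original pointer and expected the copy to keep the old value. Under the ownership-flag COWPtr above that expectation no longer holds: the original owns its payload and mutates it in place, so a non-owning copy observes the change. A hedged sketch of the new behavior (our illustration, mirroring the deleted test):

#include <cassert>

#include "paddle/fluid/framework/details/cow_ptr.h"

int main() {
  using paddle::framework::details::COWPtr;

  COWPtr<int> ptr(new int{0});  // owner
  COWPtr<int> ptr2 = ptr;       // non-owning copy, shares the payload
  *ptr.MutableData() = 10;      // owner already owns: in-place write, no copy
  assert(ptr.Data() == 10);
  assert(ptr2.Data() == 10);    // the old ASSERT_EQ(ptr2.Data(), 0) would fail
  return 0;
}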

paddle/fluid/framework/details/reference_count_op_handle.h

Lines changed: 28 additions & 13 deletions

@@ -22,6 +22,7 @@
 #include "paddle/fluid/framework/details/op_handle_base.h"
 #include "paddle/fluid/framework/garbage_collector.h"
 #include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/framework/selected_rows.h"
 #include "paddle/fluid/framework/tensor.h"
 
 namespace paddle {
@@ -46,17 +47,15 @@ class ReferenceCountOpHandle : public OpHandleBase {
                          const std::vector<std::string> &var_names,
                          GarbageCollector<Tensor> *gc,
                          AtomicReferenceCountMap *ref_cnts)
-      : OpHandleBase(node),
-        scope_(scope),
-        var_names_(var_names),
-        gc_(gc),
-        ref_cnts_(ref_cnts) {
+      : OpHandleBase(node), scope_(scope), gc_(gc), ref_cnts_(ref_cnts) {
     dev_ctx_ = static_cast<platform::CUDADeviceContext *>(
         platform::DeviceContextPool::Instance().Get(place));
    if (IsStreamGarabageCollector()) {
       PADDLE_ENFORCE(cudaSetDevice(place.device));
       PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
     }
+
+    for (auto &name : var_names) AddVar(name);
   }
 
   ~ReferenceCountOpHandle() {
@@ -69,19 +68,35 @@ class ReferenceCountOpHandle : public OpHandleBase {
 
   std::string Name() const override { return "reference_count"; }
 
+  void AddVar(const std::string &name) {
+    auto it = var_names_.find(name);
+    if (it != var_names_.end())
+      ++(it->second);
+    else
+      var_names_[name] = 1;
+  }
+
  protected:
   void RunImpl() override {
     auto *exec_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
-    std::vector<LoDTensor *> tensors;
-    for (auto &name : var_names_) {
+    std::vector<Tensor *> tensors;
+    for (auto &pair : var_names_) {
+      auto &name = pair.first;
       auto it = ref_cnts_->find(name);
       if (it == ref_cnts_->end()) continue;
 
       auto *var = exec_scope->FindVar(name);
-      if (var == nullptr || !var->IsType<LoDTensor>()) continue;
-
-      if (it->second.fetch_sub(1) <= 1) {
-        tensors.emplace_back(var->GetMutable<LoDTensor>());
+      if (var == nullptr) continue;
+
+      if (var->IsType<LoDTensor>()) {
+        if (it->second.fetch_sub(pair.second) <= pair.second) {
+          tensors.emplace_back(var->GetMutable<LoDTensor>());
+        }
+      } else if (var->IsType<SelectedRows>()) {
+        if (it->second.fetch_sub(pair.second) <= pair.second) {
+          tensors.emplace_back(
+              var->GetMutable<SelectedRows>()->mutable_value());
+        }
       }
     }
 
@@ -91,7 +106,7 @@ class ReferenceCountOpHandle : public OpHandleBase {
   }
 
  private:
-  void ClearTensors(const std::vector<LoDTensor *> &tensors) {
+  void ClearTensors(const std::vector<Tensor *> &tensors) {
     auto *gc = dynamic_cast<StreamGarbageCollector<Tensor> *>(gc_);
     if (gc != nullptr) {
       auto compute_stream = dev_ctx_->stream();
@@ -112,7 +127,7 @@ class ReferenceCountOpHandle : public OpHandleBase {
 
   const Scope *scope_;
   platform::CUDADeviceContext *dev_ctx_;
-  std::vector<std::string> var_names_;
+  std::unordered_map<std::string, int> var_names_;
   GarbageCollector<Tensor> *gc_;       // not own
   AtomicReferenceCountMap *ref_cnts_;  // not own
   cudaEvent_t event_;
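var_names_ is now a name-to-count map, so an op that references a variable n times releases all n references in a single atomic step. std::atomic::fetch_sub returns the pre-decrement value, so fetch_sub(n) <= n holds exactly when this handle released the last outstanding references. A standalone sketch of that arithmetic (illustrative names only, not Paddle code):

#include <atomic>
#include <cassert>
#include <string>
#include <unordered_map>

int main() {
  // Local counts, filled the same way as ReferenceCountOpHandle::AddVar.
  std::unordered_map<std::string, int> var_names;
  auto AddVar = [&](const std::string &name) {
    auto it = var_names.find(name);
    if (it != var_names.end())
      ++(it->second);
    else
      var_names[name] = 1;
  };
  AddVar("x");
  AddVar("x");  // "x" is referenced twice by this op

  // Global reference count for "x", shared across ops.
  std::atomic<int> global_cnt{3};

  // Another op releases its single reference first: 3 -> 2, not the last one.
  assert(global_cnt.fetch_sub(1) > 1);

  // This op releases both references at once: fetch_sub returns the
  // pre-decrement value 2, and 2 <= 2 means "x" is now collectable.
  int n = var_names["x"];
  assert(global_cnt.fetch_sub(n) <= n);
  return 0;
}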

paddle/fluid/framework/details/reference_count_pass.cc

Lines changed: 64 additions & 11 deletions

@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include <queue>
 #include <string>
 #include <vector>
 
@@ -23,6 +24,25 @@ namespace paddle {
 namespace framework {
 namespace details {
 
+static ComputationOpHandle *FindNextComputationOpHandle(VarHandle *var_in) {
+  std::queue<VarHandleBase *> queue;
+  queue.push(var_in);
+  do {
+    auto *var = queue.front();
+    queue.pop();
+    for (auto *op : var->PendingOps()) {
+      auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
+      if (compute_op != nullptr && compute_op->GetPlace() == var_in->place_) {
+        return compute_op;
+      }
+      for (auto *out_var : op->Outputs()) {
+        queue.push(out_var);
+      }
+    }
+  } while (!queue.empty());
+  return nullptr;
+}
+
 std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
     std::unique_ptr<ir::Graph> graph) const {
   auto &ref_cnts = Get<DeviceReferenceCountMap>(kGlobalReferenceCount);
@@ -34,6 +54,9 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
   // Step 2: Find all variables in non-computation ops which refers to variables
   // in computation ops
   std::unordered_set<std::string> names;
+  std::unordered_map<OpHandleBase *, std::unique_ptr<ReferenceCountOpHandle>>
+      compute_ref_cnt_map;
+
   auto get_ref_cnts_from_compute_op = [&](
       const std::unique_ptr<OpHandleBase> &op,
       const std::vector<VarHandleBase *> &vars) {
@@ -54,15 +77,18 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
       VarDesc *var_desc = var_handle->Node()->Var();
       auto var_name = var_handle->Node()->Name();
 
-      // This is wierd but there is really some variables without var_desc
+      // This is weird but there is really some variables without var_desc
       // in computation_op
       if (var_desc == nullptr) {
         if (compute_op->Node()->Op()->Block()->FindVar(var_name) == nullptr)
           continue;
       } else {
-        if (var_desc->Persistable() ||
-            var_desc->Proto()->type().type() != proto::VarType::LOD_TENSOR)
+        if (var_desc->Persistable()) continue;
+        auto var_type = var_desc->Proto()->type().type();
+        if (var_type != proto::VarType::LOD_TENSOR &&
+            var_type != proto::VarType::SELECTED_ROWS) {
           continue;
+        }
       }
 
       // compute op only runs in one device
@@ -93,12 +119,33 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
       if (ref_cnts.count(place.device) &&
           ref_cnts[place.device]->count(var_name)) {
         ++(*ref_cnts[place.device])[var_name];
+
+        auto *next_compute_op = FindNextComputationOpHandle(var_handle);
+        if (next_compute_op != nullptr) {
+          if (compute_ref_cnt_map.count(next_compute_op)) {
+            compute_ref_cnt_map[next_compute_op]->AddVar(var_name);
+            VLOG(5) << "Add reference count of " << var_name << " to Operator "
+                    << next_compute_op->Name();
+          } else {
+            // Create new reference_count_op_handle
+            ir::Node *ref_cnt_node = graph->CreateEmptyNode(
+                "reference_count", ir::Node::Type::kOperation);
+            auto *ref_cnt_handle = new ReferenceCountOpHandle(
+                ref_cnt_node, next_compute_op->GetScope(), place, {var_name},
+                gcs[place.device].get(), cur_ref_cnts[place.device].get());
+            if (next_compute_op->Outputs().empty()) {
+              auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
+              next_compute_op->AddOutput(dep_var);
+              graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
+            }
+            ref_cnt_handle->AddInput(next_compute_op->Outputs().front());
+            compute_ref_cnt_map[next_compute_op].reset(ref_cnt_handle);
+          }
+        }
       }
     }
   };
 
-  std::unordered_map<OpHandleBase *, ReferenceCountOpHandle *>
-      compute_ref_cnt_map;
   auto &all_ops = graph->Get<GraphOps>(kGraphOps);
   for (auto &op : all_ops) {
     auto in_var_names = get_ref_cnts_from_compute_op(op, op->Inputs());
@@ -113,11 +160,13 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
     auto *ref_cnt_handle = new ReferenceCountOpHandle(
         ref_cnt_node, compute_op->GetScope(), place, in_var_names,
         gcs[place.device].get(), cur_ref_cnts[place.device].get());
-    auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
-    compute_op->AddOutput(dep_var);
-    ref_cnt_handle->AddInput(dep_var);
-    graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
-    compute_ref_cnt_map[compute_op] = ref_cnt_handle;
+    if (compute_op->Outputs().empty()) {
+      auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
+      compute_op->AddOutput(dep_var);
+      graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
+    }
+    ref_cnt_handle->AddInput(compute_op->Outputs().front());
+    compute_ref_cnt_map[compute_op].reset(ref_cnt_handle);
   }
 
   for (auto &op : all_ops) {
@@ -131,7 +180,11 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
     new_all_ops.emplace_back(std::move(op));
     auto it = compute_ref_cnt_map.find(new_all_ops.back().get());
     if (it != compute_ref_cnt_map.end()) {
-      new_all_ops.emplace_back(it->second);
+      // Add LeafNode to ReferenceCountOpHandle
+      auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
+      graph->Get<GraphDepVars>(kGraphDepVars).emplace(dummy_leaf);
+      it->second->AddOutput(dummy_leaf);
+      new_all_ops.emplace_back(std::move(it->second));
    }
   }
 
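FindNextComputationOpHandle above is a plain breadth-first search from a variable to the first downstream computation op on the same place. A self-contained sketch of the same traversal over hypothetical stand-in types (Var/Op here are illustrative, not the real VarHandleBase/OpHandleBase):

#include <queue>
#include <vector>

struct Op;

struct Var {
  std::vector<Op *> pending_ops;  // ops consuming this variable
};

struct Op {
  bool is_compute = false;  // stand-in for dynamic_cast<ComputationOpHandle *>
  std::vector<Var *> outputs;
};

// Breadth-first search: walk from `start` through consuming ops and their
// outputs until the first computation op is found, as in the pass above.
Op *FindNextComputeOp(Var *start) {
  std::queue<Var *> queue;
  queue.push(start);
  while (!queue.empty()) {
    Var *var = queue.front();
    queue.pop();
    for (Op *op : var->pending_ops) {
      if (op->is_compute) return op;  // the real pass also checks the place
      for (Var *out : op->outputs) queue.push(out);
    }
  }
  return nullptr;  // no downstream computation op
}

int main() {
  Var v_in, v_mid;
  Op copy_op, compute_op;
  compute_op.is_compute = true;
  v_in.pending_ops = {&copy_op};  // v_in -> copy_op -> v_mid -> compute_op
  copy_op.outputs = {&v_mid};
  v_mid.pending_ops = {&compute_op};
  return FindNextComputeOp(&v_in) == &compute_op ? 0 : 1;
}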
