Skip to content

Commit 3043f51

Browse files
authored
Merge pull request #13511 from reyoung/fix_ce
Revert "Merge pull request #13431 from chengduoZH/refine_lod"
2 parents cffad81 + a6c8d6b commit 3043f51

File tree

9 files changed

+326
-381
lines changed

9 files changed

+326
-381
lines changed

paddle/fluid/framework/details/cow_ptr.h

Lines changed: 61 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,41 +20,79 @@ namespace paddle {
2020
namespace framework {
2121
namespace details {
2222

23-
template <class T>
24-
class COWPtr {
23+
// Change it to thread safe flags if needed.
24+
class ThreadUnsafeOwnershipFlags {
2525
public:
26-
typedef std::shared_ptr<T> RefPtr;
26+
explicit ThreadUnsafeOwnershipFlags(bool flag) : flag_(flag) {}
2727

28-
private:
29-
RefPtr m_sp;
28+
ThreadUnsafeOwnershipFlags(const ThreadUnsafeOwnershipFlags& other) = delete;
29+
ThreadUnsafeOwnershipFlags& operator=(
30+
const ThreadUnsafeOwnershipFlags& other) = delete;
31+
ThreadUnsafeOwnershipFlags(ThreadUnsafeOwnershipFlags&& other) = default;
3032

31-
void detach() {
32-
T* tmp = m_sp.get();
33-
if (!(tmp == nullptr || m_sp.unique())) {
34-
m_sp = RefPtr(new T(*tmp));
33+
void SetOwnership(bool flag) { flag_ = flag; }
34+
35+
// Invoke the callback if it is not owned.
36+
template <typename Callback>
37+
void AcquireOwnershipOnce(Callback acquire) {
38+
if (!flag_) {
39+
acquire();
40+
flag_ = true;
3541
}
3642
}
3743

38-
public:
39-
COWPtr() : m_sp(nullptr) {}
40-
explicit COWPtr(T* t) : m_sp(t) {}
41-
explicit COWPtr(const RefPtr& refptr) : m_sp(refptr) {}
44+
private:
45+
bool flag_;
46+
};
4247

43-
const T& Data() const { return operator*(); }
48+
// Copy-On-Write pointer.
49+
// It will hold a T* pointer, and only copy once when `MutableData` is invoked.
50+
//
51+
// The template parameter OwnershipFlags should have:
52+
// * a constructor takes a bool. True if own.
53+
// * SetOwnership(bool flag).
54+
// * AcquireOwnershipOnce(Callback). It will invoke the callback if it is not
55+
// owned.
56+
//
57+
// https://en.wikipedia.org/wiki/Copy-on-write
58+
template <typename T, typename OwnershipFlags = ThreadUnsafeOwnershipFlags>
59+
class COWPtr {
60+
public:
61+
// Ctor from raw pointer.
62+
explicit COWPtr(T* ptr) : payload_(ptr), ownership_{true} {}
4463

45-
T* MutableData() { return operator->(); }
64+
// Move methods. Steal ownership from origin
65+
COWPtr(COWPtr&& other)
66+
: payload_(other.payload_), ownership_{std::move(other.ownership_)} {}
67+
COWPtr& operator=(COWPtr&& origin) = default;
4668

47-
const T& operator*() const { return *m_sp; }
48-
T& operator*() {
49-
detach();
50-
return *m_sp;
69+
// Copy methods. Not own payload
70+
COWPtr(const COWPtr& other) : payload_(other.payload_), ownership_{false} {}
71+
COWPtr& operator=(const COWPtr& other) {
72+
payload_ = other.payload_;
73+
ownership_.SetOwnership(false);
74+
return *this;
5175
}
52-
const T* operator->() const { return m_sp.operator->(); }
53-
T* operator->() {
54-
detach();
55-
return m_sp.operator->();
76+
77+
// Access read only data.
78+
const T& Data() const { return *payload_; }
79+
80+
// Access mutable data. If the data is not owned, the data will be copied
81+
// before.
82+
T* MutableData() {
83+
ownership_.AcquireOwnershipOnce(
84+
[this] { payload_.reset(new T(*payload_)); });
85+
return payload_.get();
5686
}
87+
88+
private:
89+
// Actual data pointer.
90+
std::shared_ptr<T> payload_;
91+
92+
// Ownership flag.
93+
OwnershipFlags ownership_;
5794
};
95+
5896
} // namespace details
5997
} // namespace framework
6098
} // namespace paddle

paddle/fluid/framework/details/cow_ptr_test.cc

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,6 @@ TEST(COWPtr, all) {
3030
ASSERT_EQ(ptr2.Data(), 10);
3131
}
3232

33-
TEST(COWPtr, change_old) {
34-
COWPtr<int> ptr(new int{0});
35-
COWPtr<int> ptr2 = ptr;
36-
*ptr.MutableData() = 10;
37-
ASSERT_EQ(ptr2.Data(), 0);
38-
ASSERT_EQ(ptr.Data(), 10);
39-
}
40-
4133
} // namespace details
4234
} // namespace framework
4335
} // namespace paddle

0 commit comments

Comments
 (0)