@@ -20,41 +20,79 @@ namespace paddle {
20
20
namespace framework {
21
21
namespace details {
22
22
23
- template < class T >
24
- class COWPtr {
23
+ // Change it to thread safe flags if needed.
24
+ class ThreadUnsafeOwnershipFlags {
25
25
public:
26
- typedef std::shared_ptr<T> RefPtr;
26
+ explicit ThreadUnsafeOwnershipFlags ( bool flag) : flag_(flag) {}
27
27
28
- private:
29
- RefPtr m_sp;
28
+ ThreadUnsafeOwnershipFlags (const ThreadUnsafeOwnershipFlags& other) = delete ;
29
+ ThreadUnsafeOwnershipFlags& operator =(
30
+ const ThreadUnsafeOwnershipFlags& other) = delete ;
31
+ ThreadUnsafeOwnershipFlags (ThreadUnsafeOwnershipFlags&& other) = default ;
30
32
31
- void detach () {
32
- T* tmp = m_sp.get ();
33
- if (!(tmp == nullptr || m_sp.unique ())) {
34
- m_sp = RefPtr (new T (*tmp));
33
+ void SetOwnership (bool flag) { flag_ = flag; }
34
+
35
+ // Invoke the callback if it is not owned.
36
+ template <typename Callback>
37
+ void AcquireOwnershipOnce (Callback acquire) {
38
+ if (!flag_) {
39
+ acquire ();
40
+ flag_ = true ;
35
41
}
36
42
}
37
43
38
- public:
39
- COWPtr () : m_sp(nullptr ) {}
40
- explicit COWPtr (T* t) : m_sp(t) {}
41
- explicit COWPtr (const RefPtr& refptr) : m_sp(refptr) {}
44
+ private:
45
+ bool flag_;
46
+ };
42
47
43
- const T& Data () const { return operator *(); }
48
+ // Copy-On-Write pointer.
49
+ // It will hold a T* pointer, and only copy once when `MutableData` is invoked.
50
+ //
51
+ // The template parameter OwnershipFlags should have:
52
+ // * a constructor takes a bool. True if own.
53
+ // * SetOwnership(bool flag).
54
+ // * AcquireOwnershipOnce(Callback). It will invoke the callback if it is not
55
+ // owned.
56
+ //
57
+ // https://en.wikipedia.org/wiki/Copy-on-write
58
+ template <typename T, typename OwnershipFlags = ThreadUnsafeOwnershipFlags>
59
+ class COWPtr {
60
+ public:
61
+ // Ctor from raw pointer.
62
+ explicit COWPtr (T* ptr) : payload_(ptr), ownership_{true } {}
44
63
45
- T* MutableData () { return operator ->(); }
64
+ // Move methods. Steal ownership from origin
65
+ COWPtr (COWPtr&& other)
66
+ : payload_(other.payload_), ownership_{std::move (other.ownership_ )} {}
67
+ COWPtr& operator =(COWPtr&& origin) = default ;
46
68
47
- const T& operator *() const { return *m_sp; }
48
- T& operator *() {
49
- detach ();
50
- return *m_sp;
69
+ // Copy methods. Not own payload
70
+ COWPtr (const COWPtr& other) : payload_(other.payload_), ownership_{false } {}
71
+ COWPtr& operator =(const COWPtr& other) {
72
+ payload_ = other.payload_ ;
73
+ ownership_.SetOwnership (false );
74
+ return *this ;
51
75
}
52
- const T* operator ->() const { return m_sp.operator ->(); }
53
- T* operator ->() {
54
- detach ();
55
- return m_sp.operator ->();
76
+
77
+ // Access read only data.
78
+ const T& Data () const { return *payload_; }
79
+
80
+ // Access mutable data. If the data is not owned, the data will be copied
81
+ // before.
82
+ T* MutableData () {
83
+ ownership_.AcquireOwnershipOnce (
84
+ [this ] { payload_.reset (new T (*payload_)); });
85
+ return payload_.get ();
56
86
}
87
+
88
+ private:
89
+ // Actual data pointer.
90
+ std::shared_ptr<T> payload_;
91
+
92
+ // Ownership flag.
93
+ OwnershipFlags ownership_;
57
94
};
95
+
58
96
} // namespace details
59
97
} // namespace framework
60
98
} // namespace paddle
0 commit comments