@@ -20,79 +20,41 @@ namespace paddle {
20
20
namespace framework {
21
21
namespace details {
22
22
23
- // Change it to thread safe flags if needed.
24
- class ThreadUnsafeOwnershipFlags {
23
+ template < class T >
24
+ class COWPtr {
25
25
public:
26
- explicit ThreadUnsafeOwnershipFlags (bool flag) : flag_(flag) {}
27
-
28
- ThreadUnsafeOwnershipFlags (const ThreadUnsafeOwnershipFlags& other) = delete ;
29
- ThreadUnsafeOwnershipFlags& operator =(
30
- const ThreadUnsafeOwnershipFlags& other) = delete ;
31
- ThreadUnsafeOwnershipFlags (ThreadUnsafeOwnershipFlags&& other) = default ;
26
+ typedef std::shared_ptr<T> RefPtr;
32
27
33
- void SetOwnership (bool flag) { flag_ = flag; }
28
+ private:
29
+ RefPtr m_sp;
34
30
35
- // Invoke the callback if it is not owned.
36
- template <typename Callback>
37
- void AcquireOwnershipOnce (Callback acquire) {
38
- if (!flag_) {
39
- acquire ();
40
- flag_ = true ;
31
+ void detach () {
32
+ T* tmp = m_sp.get ();
33
+ if (!(tmp == nullptr || m_sp.unique ())) {
34
+ m_sp = RefPtr (new T (*tmp));
41
35
}
42
36
}
43
37
44
- private:
45
- bool flag_;
46
- };
47
-
48
- // Copy-On-Write pointer.
49
- // It will hold a T* pointer, and only copy once when `MutableData` is invoked.
50
- //
51
- // The template parameter OwnershipFlags should have:
52
- // * a constructor takes a bool. True if own.
53
- // * SetOwnership(bool flag).
54
- // * AcquireOwnershipOnce(Callback). It will invoke the callback if it is not
55
- // owned.
56
- //
57
- // https://en.wikipedia.org/wiki/Copy-on-write
58
- template <typename T, typename OwnershipFlags = ThreadUnsafeOwnershipFlags>
59
- class COWPtr {
60
38
public:
61
- // Ctor from raw pointer.
62
- explicit COWPtr (T* ptr) : payload_(ptr), ownership_{true } {}
39
+ COWPtr () : m_sp(nullptr ) {}
40
+ explicit COWPtr (T* t) : m_sp(t) {}
41
+ explicit COWPtr (const RefPtr& refptr) : m_sp(refptr) {}
63
42
64
- // Move methods. Steal ownership from origin
65
- COWPtr (COWPtr&& other)
66
- : payload_(other.payload_), ownership_{std::move (other.ownership_ )} {}
67
- COWPtr& operator =(COWPtr&& origin) = default ;
43
+ const T& Data () const { return operator *(); }
68
44
69
- // Copy methods. Not own payload
70
- COWPtr (const COWPtr& other) : payload_(other.payload_), ownership_{false } {}
71
- COWPtr& operator =(const COWPtr& other) {
72
- payload_ = other.payload_ ;
73
- ownership_.SetOwnership (false );
74
- return *this ;
75
- }
76
-
77
- // Access read only data.
78
- const T& Data () const { return *payload_; }
45
+ T* MutableData () { return operator ->(); }
79
46
80
- // Access mutable data. If the data is not owned, the data will be copied
81
- // before.
82
- T* MutableData () {
83
- ownership_.AcquireOwnershipOnce (
84
- [this ] { payload_.reset (new T (*payload_)); });
85
- return payload_.get ();
47
+ const T& operator *() const { return *m_sp; }
48
+ T& operator *() {
49
+ detach ();
50
+ return *m_sp;
51
+ }
52
+ const T* operator ->() const { return m_sp.operator ->(); }
53
+ T* operator ->() {
54
+ detach ();
55
+ return m_sp.operator ->();
86
56
}
87
-
88
- private:
89
- // Actual data pointer.
90
- std::shared_ptr<T> payload_;
91
-
92
- // Ownership flag.
93
- OwnershipFlags ownership_;
94
57
};
95
-
96
58
} // namespace details
97
59
} // namespace framework
98
60
} // namespace paddle
0 commit comments