@@ -20,79 +20,37 @@ namespace paddle {
20
20
namespace framework {
21
21
namespace details {
22
22
23
- // Change it to thread safe flags if needed.
24
- class ThreadUnsafeOwnershipFlags {
23
+ template < class T >
24
+ class COWPtr {
25
25
public:
26
- explicit ThreadUnsafeOwnershipFlags (bool flag) : flag_(flag) {}
27
-
28
- ThreadUnsafeOwnershipFlags (const ThreadUnsafeOwnershipFlags& other) = delete ;
29
- ThreadUnsafeOwnershipFlags& operator =(
30
- const ThreadUnsafeOwnershipFlags& other) = delete ;
31
- ThreadUnsafeOwnershipFlags (ThreadUnsafeOwnershipFlags&& other) = default ;
32
-
33
- void SetOwnership (bool flag) { flag_ = flag; }
34
-
35
- // Invoke the callback if it is not owned.
36
- template <typename Callback>
37
- void AcquireOwnershipOnce (Callback acquire) {
38
- if (!flag_) {
39
- acquire ();
40
- flag_ = true ;
41
- }
42
- }
26
+ typedef std::shared_ptr<T> RefPtr;
43
27
44
28
private:
45
- bool flag_;
46
- };
29
+ RefPtr m_sp;
47
30
48
- // Copy-On-Write pointer.
49
- // It will hold a T* pointer, and only copy once when `MutableData` is invoked.
50
- //
51
- // The template parameter OwnershipFlags should have:
52
- // * a constructor takes a bool. True if own.
53
- // * SetOwnership(bool flag).
54
- // * AcquireOwnershipOnce(Callback). It will invoke the callback if it is not
55
- // owned.
56
- //
57
- // https://en.wikipedia.org/wiki/Copy-on-write
58
- template <typename T, typename OwnershipFlags = ThreadUnsafeOwnershipFlags>
59
- class COWPtr {
60
31
public:
61
- // Ctor from raw pointer.
62
- explicit COWPtr (T* ptr ) : payload_(ptr), ownership_{ true } {}
32
+ COWPtr () : m_sp( nullptr ) {}
33
+ explicit COWPtr (T* t ) : m_sp(t) {}
63
34
64
- // Move methods. Steal ownership from origin
65
- COWPtr (COWPtr&& other)
66
- : payload_(other.payload_), ownership_{std::move (other.ownership_ )} {}
67
- COWPtr& operator =(COWPtr&& origin) = default ;
35
+ const T& Data () const { return *m_sp; }
68
36
69
- // Copy methods. Not own payload
70
- COWPtr (const COWPtr& other) : payload_(other.payload_), ownership_{false } {}
71
- COWPtr& operator =(const COWPtr& other) {
72
- payload_ = other.payload_ ;
73
- ownership_.SetOwnership (false );
74
- return *this ;
75
- }
76
-
77
- // Access read only data.
78
- const T& Data () const { return *payload_; }
79
-
80
- // Access mutable data. If the data is not owned, the data will be copied
81
- // before.
82
37
T* MutableData () {
83
- ownership_.AcquireOwnershipOnce (
84
- [this ] { payload_.reset (new T (*payload_)); });
85
- return payload_.get ();
38
+ DetachIfNotUnique ();
39
+ return m_sp.get ();
86
40
}
87
41
88
- private:
89
- // Actual data pointer.
90
- std::shared_ptr<T> payload_;
42
+ void DetachIfNotUnique () {
43
+ T* tmp = m_sp.get ();
44
+ if (!(tmp == nullptr || m_sp.unique ())) {
45
+ Detach ();
46
+ }
47
+ }
91
48
92
- // Ownership flag.
93
- OwnershipFlags ownership_;
49
+ void Detach () {
50
+ T* tmp = m_sp.get ();
51
+ m_sp = RefPtr (new T (*tmp));
52
+ }
94
53
};
95
-
96
54
} // namespace details
97
55
} // namespace framework
98
56
} // namespace paddle
0 commit comments