Skip to content

Commit 0be1582

Browse files
authored
Merge pull request #13525 from reyoung/fix_mixed_vector
Fix mixed vector
2 parents 4e81e22 + e1913bc commit 0be1582

File tree

11 files changed

+416
-331
lines changed

11 files changed

+416
-331
lines changed

paddle/fluid/framework/details/cow_ptr.h

Lines changed: 19 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -20,79 +20,37 @@ namespace paddle {
2020
namespace framework {
2121
namespace details {
2222

23-
// Change it to thread safe flags if needed.
24-
class ThreadUnsafeOwnershipFlags {
23+
template <class T>
24+
class COWPtr {
2525
public:
26-
explicit ThreadUnsafeOwnershipFlags(bool flag) : flag_(flag) {}
27-
28-
ThreadUnsafeOwnershipFlags(const ThreadUnsafeOwnershipFlags& other) = delete;
29-
ThreadUnsafeOwnershipFlags& operator=(
30-
const ThreadUnsafeOwnershipFlags& other) = delete;
31-
ThreadUnsafeOwnershipFlags(ThreadUnsafeOwnershipFlags&& other) = default;
32-
33-
void SetOwnership(bool flag) { flag_ = flag; }
34-
35-
// Invoke the callback if it is not owned.
36-
template <typename Callback>
37-
void AcquireOwnershipOnce(Callback acquire) {
38-
if (!flag_) {
39-
acquire();
40-
flag_ = true;
41-
}
42-
}
26+
typedef std::shared_ptr<T> RefPtr;
4327

4428
private:
45-
bool flag_;
46-
};
29+
RefPtr m_sp;
4730

48-
// Copy-On-Write pointer.
49-
// It will hold a T* pointer, and only copy once when `MutableData` is invoked.
50-
//
51-
// The template parameter OwnershipFlags should have:
52-
// * a constructor takes a bool. True if own.
53-
// * SetOwnership(bool flag).
54-
// * AcquireOwnershipOnce(Callback). It will invoke the callback if it is not
55-
// owned.
56-
//
57-
// https://en.wikipedia.org/wiki/Copy-on-write
58-
template <typename T, typename OwnershipFlags = ThreadUnsafeOwnershipFlags>
59-
class COWPtr {
6031
public:
61-
// Ctor from raw pointer.
62-
explicit COWPtr(T* ptr) : payload_(ptr), ownership_{true} {}
32+
COWPtr() : m_sp(nullptr) {}
33+
explicit COWPtr(T* t) : m_sp(t) {}
6334

64-
// Move methods. Steal ownership from origin
65-
COWPtr(COWPtr&& other)
66-
: payload_(other.payload_), ownership_{std::move(other.ownership_)} {}
67-
COWPtr& operator=(COWPtr&& origin) = default;
35+
const T& Data() const { return *m_sp; }
6836

69-
// Copy methods. Not own payload
70-
COWPtr(const COWPtr& other) : payload_(other.payload_), ownership_{false} {}
71-
COWPtr& operator=(const COWPtr& other) {
72-
payload_ = other.payload_;
73-
ownership_.SetOwnership(false);
74-
return *this;
75-
}
76-
77-
// Access read only data.
78-
const T& Data() const { return *payload_; }
79-
80-
// Access mutable data. If the data is not owned, the data will be copied
81-
// before.
8237
T* MutableData() {
83-
ownership_.AcquireOwnershipOnce(
84-
[this] { payload_.reset(new T(*payload_)); });
85-
return payload_.get();
38+
DetachIfNotUnique();
39+
return m_sp.get();
8640
}
8741

88-
private:
89-
// Actual data pointer.
90-
std::shared_ptr<T> payload_;
42+
void DetachIfNotUnique() {
43+
T* tmp = m_sp.get();
44+
if (!(tmp == nullptr || m_sp.unique())) {
45+
Detach();
46+
}
47+
}
9148

92-
// Ownership flag.
93-
OwnershipFlags ownership_;
49+
void Detach() {
50+
T* tmp = m_sp.get();
51+
m_sp = RefPtr(new T(*tmp));
52+
}
9453
};
95-
9654
} // namespace details
9755
} // namespace framework
9856
} // namespace paddle

paddle/fluid/framework/details/cow_ptr_test.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,14 @@ TEST(COWPtr, all) {
3030
ASSERT_EQ(ptr2.Data(), 10);
3131
}
3232

33+
TEST(COWPtr, change_old) {
34+
COWPtr<int> ptr(new int{0});
35+
COWPtr<int> ptr2 = ptr;
36+
*ptr.MutableData() = 10;
37+
ASSERT_EQ(ptr2.Data(), 0);
38+
ASSERT_EQ(ptr.Data(), 10);
39+
}
40+
3341
} // namespace details
3442
} // namespace framework
3543
} // namespace paddle

0 commit comments

Comments
 (0)