Skip to content

Commit 6e3cc0c

Browse files
authored
Merge pull request #7240 from reyoung/feature/make_lod_a_share_ptr
Add COWPtr and its unittest
2 parents e3d296f + 3b0afae commit 6e3cc0c

File tree

3 files changed

+134
-1
lines changed

3 files changed

+134
-1
lines changed

paddle/framework/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,6 @@ cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece operat
7979
cc_test(init_test SRCS init_test.cc DEPS init)
8080

8181
cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto)
82-
82+
cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc)
8383
nv_test(device_data_transform_test SRCS device_data_transform_test.cu
8484
DEPS operator op_registry init math_function)

paddle/framework/details/cow_ptr.h

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include <memory>
17+
#include <thread>
18+
19+
namespace paddle {
20+
namespace framework {
21+
namespace details {
22+
23+
// Change it to thread safe flags if needed.
24+
class ThreadUnsafeOwnershipFlags {
25+
public:
26+
ThreadUnsafeOwnershipFlags(bool flag) : flag_(flag) {}
27+
28+
ThreadUnsafeOwnershipFlags(const ThreadUnsafeOwnershipFlags& other) = delete;
29+
ThreadUnsafeOwnershipFlags& operator=(
30+
const ThreadUnsafeOwnershipFlags& other) = delete;
31+
ThreadUnsafeOwnershipFlags(ThreadUnsafeOwnershipFlags&& other) = default;
32+
33+
void SetOwnership(bool flag) { flag_ = flag; }
34+
35+
// Invoke the callback if it is not owned.
36+
template <typename Callback>
37+
void AcquireOwnershipOnce(Callback acquire) {
38+
if (!flag_) {
39+
acquire();
40+
flag_ = true;
41+
}
42+
}
43+
44+
private:
45+
bool flag_;
46+
};
47+
48+
// Copy-On-Write pointer.
49+
// It will hold a T* pointer, and only copy once when `MutableData` is invoked.
50+
//
51+
// The template parameter OwnershipFlags should have:
52+
// * a constructor takes a bool. True if own.
53+
// * SetOwnership(bool flag).
54+
// * AcquireOwnershipOnce(Callback). It will invoke the callback if it is not
55+
// owned.
56+
//
57+
// https://en.wikipedia.org/wiki/Copy-on-write
58+
template <typename T, typename OwnershipFlags = ThreadUnsafeOwnershipFlags>
59+
class COWPtr {
60+
public:
61+
// Ctor from raw pointer.
62+
explicit COWPtr(T* ptr) : payload_(ptr), ownership_{true} {}
63+
64+
// Move methods. Steal ownership from origin
65+
COWPtr(COWPtr&& other)
66+
: payload_(other.payload_), ownership_{std::move(other.ownership_)} {}
67+
COWPtr& operator=(COWPtr&& origin) = default;
68+
69+
// Copy methods. Not own payload
70+
COWPtr(const COWPtr& other) : payload_(other.payload_), ownership_{false} {}
71+
COWPtr& operator=(const COWPtr& other) {
72+
payload_ = other.payload_;
73+
ownership_.SetOwnership(false);
74+
return *this;
75+
}
76+
77+
// Access read only data.
78+
const T& Data() const { return *payload_; }
79+
80+
// Access mutable data. If the data is not owned, the data will be copied
81+
// before.
82+
T* MutableData() {
83+
ownership_.AcquireOwnershipOnce(
84+
[this] { payload_.reset(new T(*payload_)); });
85+
return payload_.get();
86+
}
87+
88+
private:
89+
// Actual data pointer.
90+
std::shared_ptr<T> payload_;
91+
92+
// Ownership flag.
93+
OwnershipFlags ownership_;
94+
};
95+
96+
} // namespace details
97+
} // namespace framework
98+
} // namespace paddle
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/framework/details/cow_ptr.h"
16+
#include "gtest/gtest.h"
17+
18+
namespace paddle {
19+
namespace framework {
20+
namespace details {
21+
22+
TEST(COWPtr, all) {
23+
COWPtr<int> ptr(new int{0});
24+
ASSERT_EQ(ptr.Data(), 0);
25+
COWPtr<int> ptr2 = ptr;
26+
ASSERT_EQ(ptr2.Data(), 0);
27+
ASSERT_EQ(&ptr2.Data(), &ptr.Data());
28+
*ptr2.MutableData() = 10;
29+
ASSERT_EQ(ptr.Data(), 0);
30+
ASSERT_EQ(ptr2.Data(), 10);
31+
}
32+
33+
} // namespace details
34+
} // namespace framework
35+
} // namespace paddle

0 commit comments

Comments
 (0)