Skip to content

Commit 0cfb546

Browse files
committed
Add COWPtr and its unittest
It will be used for LoD information in LoDTensor since LoD is a copy on write field. It is pretty slow for copying LoD information between operators. For resnet it will cost roughly 10% time of whole time, including reading data.
1 parent 5911644 commit 0cfb546

File tree

3 files changed

+134
-0
lines changed

3 files changed

+134
-0
lines changed

paddle/framework/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,4 @@ cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece)
7878
cc_test(init_test SRCS init_test.cc DEPS init)
7979

8080
cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto)
81+
cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc)

paddle/framework/details/cow_ptr.h

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include <memory>
17+
#include <thread>
18+
19+
namespace paddle {
20+
namespace framework {
21+
namespace details {
22+
23+
// Change it to thread safe flags if needed.
24+
class ThreadUnsafeOwnershipFlags {
25+
public:
26+
ThreadUnsafeOwnershipFlags(bool flag) : flag_(flag) {}
27+
28+
ThreadUnsafeOwnershipFlags(const ThreadUnsafeOwnershipFlags& o) = delete;
29+
ThreadUnsafeOwnershipFlags& operator=(const ThreadUnsafeOwnershipFlags& o) =
30+
delete;
31+
ThreadUnsafeOwnershipFlags(ThreadUnsafeOwnershipFlags&& o) = default;
32+
33+
void SetOwnership(bool flag) { flag_ = flag; }
34+
35+
template <typename Callback>
36+
void AcquireOwnershipOnce(Callback acquire) {
37+
if (!flag_) {
38+
acquire();
39+
flag_ = true;
40+
}
41+
}
42+
43+
private:
44+
bool flag_;
45+
};
46+
47+
// Copy On Write pointer.
48+
// It will hold a T* pointer, and only copy once when `MutableData` is invoked.
49+
//
50+
// The template parameter OwnershipFlags should have:
51+
// * a constructor takes a bool. True if own.
52+
// * SetOwnership(bool flag).
53+
// * AcquireOwnershipOnce(Callback). It will invoke the callback if it is not
54+
// owned.
55+
template <typename T, typename OwnershipFlags = ThreadUnsafeOwnershipFlags>
56+
class COWPtr {
57+
public:
58+
// Ctor from raw pointer.
59+
explicit COWPtr(T* ptr) : payload_(ptr), ownership_{true} {}
60+
61+
// Move methods. Steal ownership from origin
62+
COWPtr(COWPtr&& o)
63+
: payload_(o.payload_), ownership_{std::move(o.ownership_)} {}
64+
COWPtr& operator=(COWPtr&& origin) = default;
65+
66+
// Copy methods. Not own payload
67+
COWPtr(const COWPtr& o) : payload_(o.payload_), ownership_{false} {}
68+
COWPtr& operator=(const COWPtr& o) {
69+
payload_ = o.payload_;
70+
ownership_.SetOwnership(false);
71+
return *this;
72+
}
73+
74+
const T& Data() const { return *payload_; }
75+
76+
T* MutableData() {
77+
ownership_.AcquireOwnershipOnce(
78+
[this] { payload_.reset(new T(*payload_)); });
79+
return payload_.get();
80+
}
81+
82+
void Reset() {
83+
ownership_.AcquireOwnershipOnce([this] { payload_.reset(); });
84+
payload_.reset(new T());
85+
}
86+
87+
private:
88+
std::shared_ptr<T> payload_;
89+
OwnershipFlags ownership_;
90+
};
91+
92+
} // namespace details
93+
} // namespace framework
94+
} // namespace paddle
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/framework/details/cow_ptr.h"
16+
#include "gtest/gtest.h"
17+
18+
namespace paddle {
19+
namespace framework {
20+
namespace details {
21+
22+
TEST(COWPtr, all) {
23+
COWPtr<int> ptr(new int{0});
24+
ASSERT_EQ(ptr.Data(), 0);
25+
COWPtr<int> ptr2 = ptr;
26+
ASSERT_EQ(ptr2.Data(), 0);
27+
ASSERT_EQ(&ptr2.Data(), &ptr.Data());
28+
*ptr2.MutableData() = 10;
29+
ASSERT_EQ(ptr.Data(), 0);
30+
ASSERT_EQ(ptr2.Data(), 10);
31+
32+
auto ptr_before = ptr2.MutableData();
33+
ptr2.Reset();
34+
ASSERT_NE(ptr2.MutableData(), ptr_before);
35+
}
36+
37+
} // namespace details
38+
} // namespace framework
39+
} // namespace paddle

0 commit comments

Comments
 (0)