@@ -22,41 +22,6 @@ namespace memory {
 namespace allocation {
 
 BufferedAllocator::BufferedAllocator(std::unique_ptr<Allocator>&& allocator) {
-  std::vector<size_t> division_plan(8 * sizeof(size_t));
-  for (size_t i = 0; i < 8 * sizeof(size_t); ++i) {
-    division_plan[i] = (static_cast<size_t>(1) << i);
-  }
-  InitAndEnforceCheck(std::move(allocator), division_plan);
-}
-
-BufferedAllocator::BufferedAllocator(std::unique_ptr<Allocator>&& allocator,
-                                     const std::vector<size_t>& division_plan) {
-  InitAndEnforceCheck(std::move(allocator), division_plan);
-}
-
-BufferedAllocator::~BufferedAllocator() { FlushImpl(); }
-
-void BufferedAllocator::FlushImpl() {
-  for (auto& v : allocations_) {
-    for (auto& pair : v) {
-      underlying_allocator_->FreeUniquePtr(std::move(pair.second));
-    }
-    v.clear();
-  }
-}
-
-void BufferedAllocator::Flush() {
-  if (mtx_) {
-    std::lock_guard<std::mutex> lock(*mtx_);
-    FlushImpl();
-  } else {
-    FlushImpl();
-  }
-}
-
-void BufferedAllocator::InitAndEnforceCheck(
-    std::unique_ptr<Allocator>&& allocator,
-    const std::vector<size_t>& division_plan) {
   underlying_allocator_.reset(
       dynamic_cast<UnmanagedAllocator*>(allocator.release()));
   PADDLE_ENFORCE_NOT_NULL(
@@ -65,141 +30,54 @@ void BufferedAllocator::InitAndEnforceCheck(
   if (underlying_allocator_->IsAllocThreadSafe()) {
     mtx_.reset(new std::mutex());
   }
-  constexpr size_t kMax = std::numeric_limits<size_t>::max();
-  if (division_plan.empty()) {
-    division_plan_.assign({0, kMax});
-  } else {
-    auto from = division_plan.front() == 0 ? division_plan.begin() + 1
-                                           : division_plan.begin();
-    auto to = division_plan.back() == kMax ? division_plan.end() - 1
-                                           : division_plan.end();
-    division_plan_.reserve(to - from + 2);
-    division_plan_.push_back(0);
-    division_plan_.insert(division_plan_.end(), from, to);
-    division_plan_.push_back(kMax);
-    for (size_t i = 1; i < division_plan_.size(); ++i) {
-      PADDLE_ENFORCE_LT(division_plan_[i - 1], division_plan_[i],
-                        "Division plan must be strictly sorted");
-    }
-  }
-  allocations_.resize(division_plan_.size() - 1);
-}
-
-void BufferedAllocator::InsertAllocationImpl(
-    std::unique_ptr<Allocation>&& allocation) {
-  auto size = allocation->size();
-  auto idx = GetListIndex(size);
-  allocations_[idx].emplace(size, std::move(allocation));
-}
-
-void BufferedAllocator::InsertAllocation(
-    std::unique_ptr<Allocation>&& allocation) {
-  if (mtx_) {
-    std::lock_guard<std::mutex> lock(*mtx_);
-    InsertAllocationImpl(std::move(allocation));
-  } else {
-    InsertAllocationImpl(std::move(allocation));
-  }
 }
 
-bool BufferedAllocator::Match(size_t actual_size, size_t requested_size) {
-  return (actual_size >> 1) < requested_size;
-}
-
-size_t BufferedAllocator::GetListIndex(size_t size) {
-  auto it =
-      std::upper_bound(division_plan_.begin(), division_plan_.end(), size);
-  return static_cast<size_t>(it - division_plan_.begin()) - 1;
-}
+BufferedAllocator::~BufferedAllocator() { FreeCache(-1UL); }
 
-std::unique_ptr<Allocation> BufferedAllocator::RemoveAllocationImpl(
-    size_t size) {
-  auto idx = GetListIndex(size);
-  auto& allocation_map = allocations_[idx];
-  auto it = allocation_map.lower_bound(size);
-  // Only remove allocation whose size is not more than twice of requested size
-  if (it != allocation_map.end()) {
-    if (Match(it->second->size(), size)) {
-      auto ret = std::move(it->second);
-      allocation_map.erase(it);
-      return ret;
-    } else {
-      return nullptr;
-    }
-  } else {
-    while (++idx < allocations_.size() && Match(division_plan_[idx], size)) {
-      auto& allocation_map = allocations_[idx];
-      if (!allocation_map.empty()) {
-        auto it = allocation_map.begin();
-        if (Match(it->second->size(), size)) {
-          auto ret = std::move(it->second);
-          allocation_map.erase(it);
-          return ret;
-        } else {
-          return nullptr;
-        }
-      }
+std::unique_ptr<Allocation> BufferedAllocator::Allocate(size_t size,
+                                                        Allocator::Attr attr) {
+  std::unique_ptr<Allocation> result;
+  {
+    platform::LockGuardPtr<std::mutex> guard(mtx_);
+    auto it = allocations_.lower_bound(size);
+    if (it != allocations_.end() && it->first < size * 2) {
+      result = std::move(it->second);
+      allocations_.erase(it);
     }
-    return nullptr;
   }
-}
 
-std::unique_ptr<Allocation> BufferedAllocator::RemoveAllocation(size_t size) {
-  if (mtx_) {
-    std::lock_guard<std::mutex> lock(*mtx_);
-    return RemoveAllocationImpl(size);
-  } else {
-    return RemoveAllocationImpl(size);
+  if (result) {
+    return result;
   }
-}
 
-std::unique_ptr<Allocation> BufferedAllocator::Allocate(size_t size,
-                                                        Allocator::Attr attr) {
-  auto ret = RemoveAllocation(size);
-  if (!ret) {
-    try {
-      return underlying_allocator_->Allocate(size, attr);
-    } catch (BadAlloc&) {
-      // if allocation failed, try to free some memorys from buffers
-      FreeAllocations(size);
-      return underlying_allocator_->Allocate(size, attr);
-    }
+  try {
+    return underlying_allocator_->Allocate(size, attr);
+  } catch (BadAlloc&) {
+    FreeCache(size);
+    return underlying_allocator_->Allocate(size, attr);
   }
-  return ret;
 }
 
-void BufferedAllocator::FreeAllocationsImpl(size_t size) {
+void BufferedAllocator::FreeCache(size_t size) {
+  platform::LockGuardPtr<std::mutex> guard(mtx_);
   if (UNLIKELY(size == 0)) return;
   size_t cur = 0;
-  for (auto& alloc_map : allocations_) {
-    // use reverse iterator to free large allocations first
-    while (!alloc_map.empty()) {
-      auto it = --(alloc_map.end());
-      cur += it->second->size();
-      underlying_allocator_->FreeUniquePtr(std::move(it->second));
-      alloc_map.erase(it);
-      if (cur >= size) return;
-    }
-  }
-}
-
-void BufferedAllocator::FreeAllocations(size_t size) {
-  if (mtx_) {
-    std::lock_guard<std::mutex> lock(*mtx_);
-    FreeAllocationsImpl(size);
-  } else {
-    FreeAllocationsImpl(size);
+  while (!allocations_.empty()) {  // free the largest
+    auto it = --allocations_.end();
+    cur += it->second->size();
+    underlying_allocator_->FreeUniquePtr(std::move(it->second));
+    allocations_.erase(it);
+    if (cur >= size) return;
   }
 }
 
 void BufferedAllocator::FreeUniquePtr(std::unique_ptr<Allocation> allocation) {
-  InsertAllocation(std::move(allocation));
+  platform::LockGuardPtr<std::mutex> guard(mtx_);
+  allocations_.emplace(allocation->size(), std::move(allocation));
 }
 
-bool BufferedAllocator::IsAllocThreadSafe() const { return mtx_ != nullptr; }
-
-const std::vector<size_t>& BufferedAllocator::GetDivisionPlan() const {
-  return division_plan_;
+bool BufferedAllocator::IsAllocThreadSafe() const {
+  return this->underlying_allocator_->IsAllocThreadSafe();
}
 
 }  // namespace allocation
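
Note on the policy this diff lands: the per-bucket free lists (`division_plan_` plus a vector of size-keyed maps) are replaced by a single `std::multimap` keyed by allocation size. `Allocate` reuses the smallest cached block that is at least `size` bytes but less than `size * 2`, `FreeUniquePtr` parks freed blocks in the cache instead of releasing them, and on `BadAlloc` the cache is drained largest-first until at least `size` bytes have been returned. A minimal standalone sketch of that policy follows; `RawAllocator`, `Block`, and the malloc/free plumbing are illustrative stand-ins, not PaddlePaddle APIs.

    // Standalone sketch of the caching policy introduced by the diff above.
    #include <cstdlib>
    #include <iostream>
    #include <map>
    #include <new>

    struct Block {
      void* ptr;
      size_t size;
    };

    // Hypothetical backing allocator; UnmanagedAllocator plays this role
    // in the real code.
    struct RawAllocator {
      Block Allocate(size_t size) {
        void* p = std::malloc(size);
        if (p == nullptr) throw std::bad_alloc();  // diff catches BadAlloc
        return {p, size};
      }
      void Free(Block b) { std::free(b.ptr); }
    };

    class BufferedAllocatorSketch {
     public:
      Block Allocate(size_t size) {
        // Reuse the smallest cached block that fits but wastes < 2x.
        auto it = cache_.lower_bound(size);
        if (it != cache_.end() && it->first < size * 2) {
          Block b = it->second;
          cache_.erase(it);
          return b;
        }
        try {
          return raw_.Allocate(size);
        } catch (std::bad_alloc&) {
          FreeCache(size);  // evict at least `size` bytes, largest first
          return raw_.Allocate(size);
        }
      }

      void Free(Block b) { cache_.emplace(b.size, b); }  // park, don't release

      void FreeCache(size_t size) {
        size_t cur = 0;
        while (!cache_.empty()) {
          auto it = --cache_.end();  // largest cached block
          cur += it->first;
          raw_.Free(it->second);
          cache_.erase(it);
          if (cur >= size) return;
        }
      }

     private:
      RawAllocator raw_;
      std::multimap<size_t, Block> cache_;  // freed blocks keyed by size
    };

    int main() {
      BufferedAllocatorSketch a;
      Block b = a.Allocate(1024);
      a.Free(b);                  // cached, not returned to the system
      Block c = a.Allocate(700);  // 700 <= 1024 < 1400: cached block reused
      std::cout << (b.ptr == c.ptr) << "\n";  // prints 1
      a.Free(c);
      a.FreeCache(static_cast<size_t>(-1));  // drain, as the destructor does
    }

Running this prints `1`: the 1024-byte block freed into the cache is handed back for the 700-byte request, since 1024 < 700 * 2.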
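
The diff also drops the duplicated `if (mtx_) { std::lock_guard<std::mutex> lock(*mtx_); ... } else { ... }` bodies in favor of `platform::LockGuardPtr<std::mutex>`, constructed directly from the possibly-null `mtx_` and locking only when a mutex actually exists (`mtx_` is only allocated in the constructor when the underlying allocator reports thread safety). A null-tolerant guard in that spirit might look like the sketch below; this is an assumed shape, not PaddlePaddle's actual implementation.

    #include <memory>
    #include <mutex>

    // Sketch of a lock guard that tolerates a null mutex pointer: a null
    // `lockable` means the owner decided no synchronization is needed.
    template <typename Lockable>
    class LockGuardPtr {
     public:
      explicit LockGuardPtr(const std::unique_ptr<Lockable>& lockable)
          : lockable_(lockable.get()) {
        if (lockable_ != nullptr) lockable_->lock();
      }
      ~LockGuardPtr() {
        if (lockable_ != nullptr) lockable_->unlock();
      }
      LockGuardPtr(const LockGuardPtr&) = delete;
      LockGuardPtr& operator=(const LockGuardPtr&) = delete;

     private:
      Lockable* lockable_;
    };

This collapses each lock-then-call/call-without-lock pair in the old code (`Flush`, `InsertAllocation`, `RemoveAllocation`, `FreeAllocations`) into a single method body, which is why the diff deletes all of the `*Impl` wrappers.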