
Commit ea81f8e

Clean interface of allocator
Clean managed/unmanaged allocator
1 parent 0263196 commit ea81f8e

26 files changed: +347, -585 lines

paddle/fluid/memory/allocation/CMakeLists.txt

Lines changed: 1 addition & 5 deletions
@@ -29,9 +29,6 @@ else()
         cpu_allocator)
 endif()
 
-
-cc_library(naive_managed_allocator SRCS naive_managed_allocator.cc DEPS allocator)
-cc_test(naive_managed_allocator_test SRCS naive_managed_allocator_test.cc DEPS naive_managed_allocator)
 nv_library(pinned_allocator SRCS pinned_allocator.cc DEPS allocator)
 if (WITH_GPU)
   set(AllocatorFacadeDeps gpu_info cuda_allocator pinned_allocator cuda_device_guard)
@@ -49,7 +46,6 @@ cc_library(allocator_facade SRCS allocator_facade.cc DEPS
         cpu_allocator
         locked_allocator
         best_fit_allocator
-        naive_managed_allocator
        aligned_allocator
        auto_increment_allocator
        zero_size_allocator
@@ -61,6 +57,6 @@ cc_library(allocator_facade SRCS allocator_facade.cc DEPS
 
 nv_test(allocation_and_eigen_test SRCS allocation_and_eigen_test.cu DEPS allocator_facade)
 
-cc_test(retry_allocator_test SRCS retry_allocator_test.cc DEPS retry_allocator naive_managed_allocator best_fit_allocator locked_allocator cpu_allocator)
+cc_test(retry_allocator_test SRCS retry_allocator_test.cc DEPS retry_allocator best_fit_allocator locked_allocator cpu_allocator)
 
 cc_test(allocator_facade_test SRCS allocator_facade_test.cc DEPS allocator_facade)

paddle/fluid/memory/allocation/aligned_allocator.cc

Lines changed: 1 addition & 6 deletions
@@ -19,14 +19,9 @@ namespace memory {
 namespace allocation {
 
 ThinAlignedAllocator::ThinAlignedAllocator(
-    std::shared_ptr<ManagedAllocator> underlyning_allocator)
+    std::shared_ptr<Allocator> underlyning_allocator)
     : underlying_allocator_(std::move(underlyning_allocator)) {}
 
-std::shared_ptr<Allocation> ThinAlignedAllocator::AllocateShared(
-    size_t size, Allocator::Attr attr) {
-  return std::shared_ptr<Allocation>(Allocate(size, attr).release());
-}
-
 bool ThinAlignedAllocator::IsAllocThreadSafe() const {
   return underlying_allocator_->IsAllocThreadSafe();
 }

paddle/fluid/memory/allocation/aligned_allocator.h

Lines changed: 3 additions & 5 deletions
@@ -70,17 +70,15 @@ class AlignedAllocation : public Allocation {
 //
 // NOTE(yy): This could be an over design. If it harms readability of code, it
 // could be removed later.
-class ThinAlignedAllocator : public ManagedAllocator {
+class ThinAlignedAllocator : public Allocator {
  public:
   explicit ThinAlignedAllocator(
-      std::shared_ptr<ManagedAllocator> underlyning_allocator);
-
-  std::shared_ptr<Allocation> AllocateShared(size_t size, Attr attr) override;
+      std::shared_ptr<Allocator> underlyning_allocator);
 
   bool IsAllocThreadSafe() const;
 
  protected:
-  std::shared_ptr<ManagedAllocator> underlying_allocator_;
+  std::shared_ptr<Allocator> underlying_allocator_;
 };
 
 // An aligned allocator will allocate `size+kAlignment` allocation and adjust
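
For orientation, here is a minimal usage sketch of the reworked decorator; the AlignedCpuAlloc helper and its call pattern are assumptions for illustration, while CPUAllocator, AlignedAllocator<64u>, and Allocator::kDefault all appear elsewhere in this commit's diffs.

#include <memory>

#include "paddle/fluid/memory/allocation/aligned_allocator.h"
#include "paddle/fluid/memory/allocation/cpu_allocator.h"

namespace alloc = paddle::memory::allocation;

// Hypothetical call site: wrap a plain Allocator in the 64-byte aligned
// decorator, which after this change takes std::shared_ptr<Allocator>
// instead of std::shared_ptr<ManagedAllocator>.
std::unique_ptr<alloc::Allocation> AlignedCpuAlloc(size_t size) {
  std::shared_ptr<alloc::Allocator> base(new alloc::CPUAllocator());
  auto aligned = std::make_shared<alloc::AlignedAllocator<64u>>(base);
  // Allocate() still returns a unique_ptr<Allocation>; the wrapped pointer
  // is adjusted so that it is 64-byte aligned.
  return aligned->Allocate(size, alloc::Allocator::kDefault);
}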

paddle/fluid/memory/allocation/allocator.cc

Lines changed: 5 additions & 0 deletions
@@ -24,6 +24,11 @@ bool Allocator::IsAllocThreadSafe() const { return false; }
 
 const char* BadAlloc::what() const noexcept { return msg_.c_str(); }
 
+MannualFreeAllocation::~MannualFreeAllocation() { allocator_->Free(this); }
+std::unique_ptr<Allocation> MannualFreeAllocator::Allocate(
+    size_t size, Allocator::Attr attr) {
+  return std::unique_ptr<Allocation>(AllocateImpl(size, attr));
+}
 } // namespace allocation
 } // namespace memory
 } // namespace paddle
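
Taken together, these five added lines move deallocation into the allocation object itself: destroying a MannualFreeAllocation calls back into its owning allocator's Free(), while MannualFreeAllocator::Allocate merely wraps the raw result of AllocateImpl in a unique_ptr. This is what lets the separate AllocateShared entry points be deleted in the rest of the commit: a shared_ptr built from the released unique_ptr frees memory through the same destructor path.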

paddle/fluid/memory/allocation/allocator.h

Lines changed: 20 additions & 9 deletions
@@ -121,19 +121,30 @@ class Allocator {
   virtual bool IsAllocThreadSafe() const;
 };
 
-// User need to invoke `Free` or `FreeUniquePtr` manually if allocated by
-// a manally managed allocator.
-class UnmanagedAllocator : public Allocator {
+class MannualFreeAllocator;
+class MannualFreeAllocation : public Allocation {
  public:
-  virtual void FreeUniquePtr(std::unique_ptr<Allocation> allocation) = 0;
+  MannualFreeAllocation(MannualFreeAllocator* allocator, void* ptr, size_t size,
+                        platform::Place place)
+      : Allocation(ptr, size, place), allocator_(allocator) {}
+
+  ~MannualFreeAllocation();
+
+ private:
+  MannualFreeAllocator* allocator_;
 };
 
-// The allocation will be managed by smart pointers. i.e., users do not need
-// to free allocation manually.
-class ManagedAllocator : public Allocator {
+// User need to invoke `Free` or `FreeUniquePtr` manually if allocated by
+// a manally managed allocator.
+class MannualFreeAllocator : public Allocator {
  public:
-  virtual std::shared_ptr<Allocation> AllocateShared(
-      size_t size, Allocator::Attr attr = kDefault) = 0;
+  std::unique_ptr<Allocation> Allocate(size_t size, Attr attr) final;
+
+ protected:
+  virtual void Free(MannualFreeAllocation* allocation) = 0;
+  virtual MannualFreeAllocation* AllocateImpl(size_t size,
+                                              Allocator::Attr attr) = 0;
+  friend class MannualFreeAllocation;
 };
 
 } // namespace allocation
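
A minimal sketch of what a concrete allocator looks like under the new interface. The BufferAllocator name and its malloc/free body are illustrative assumptions, not part of the commit, and the sketch assumes Allocation exposes a ptr() getter and that PADDLE_ENFORCE and platform::CPUPlace are visible through allocator.h.

#include <cstdlib>

#include "paddle/fluid/memory/allocation/allocator.h"

namespace paddle {
namespace memory {
namespace allocation {

// Hypothetical subclass: only the two protected hooks are implemented.
// Callers go through the public, final Allocator::Allocate; memory comes
// back when the MannualFreeAllocation is destroyed, because its destructor
// calls allocator_->Free(this).
class BufferAllocator : public MannualFreeAllocator {
 protected:
  MannualFreeAllocation* AllocateImpl(size_t size,
                                      Allocator::Attr attr) override {
    void* ptr = std::malloc(size);
    PADDLE_ENFORCE(ptr != nullptr, "Cannot allocate %d bytes", size);
    return new MannualFreeAllocation(this, ptr, size, platform::CPUPlace());
  }

  void Free(MannualFreeAllocation* allocation) override {
    // Release the underlying buffer only; the MannualFreeAllocation object
    // itself is being destroyed by its owner (unique_ptr or shared_ptr).
    std::free(allocation->ptr());
  }
};

}  // namespace allocation
}  // namespace memory
}  // namespace paddle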

paddle/fluid/memory/allocation/allocator_facade.cc

Lines changed: 15 additions & 24 deletions
@@ -24,7 +24,6 @@
 #include "paddle/fluid/memory/allocation/conditional_allocator.h"
 #include "paddle/fluid/memory/allocation/cpu_allocator.h"
 #include "paddle/fluid/memory/allocation/locked_allocator.h"
-#include "paddle/fluid/memory/allocation/naive_managed_allocator.h"
 #include "paddle/fluid/memory/allocation/retry_allocator.h"
 #include "paddle/fluid/memory/allocation/zero_size_allocator.h"
 #include "paddle/fluid/platform/cpu_info.h"
@@ -46,34 +45,28 @@ namespace memory {
 namespace allocation {
 
 // TODO(yy): Dirty code here. This class should be configurable in runtime.
-class CPUManagedAllocator : public ManagedAllocator {
+class CPUManagedAllocator : public Allocator {
  public:
-  CPUManagedAllocator()
-      : normal_allocator_(NaiveManagedAllocator::Create(
-            std::unique_ptr<Allocator>(new CPUAllocator()))) {}
+  CPUManagedAllocator() : normal_allocator_(new CPUAllocator()) {}
 
   std::unique_ptr<Allocation> Allocate(size_t size, Attr attr) override {
     return normal_allocator_->Allocate(size, attr);
   }
 
-  std::shared_ptr<Allocation> AllocateShared(size_t size, Attr attr) override {
-    return normal_allocator_->AllocateShared(size, attr);
-  }
-
   bool IsAllocThreadSafe() const override { return true; }
 
  private:
-  std::shared_ptr<ManagedAllocator> normal_allocator_;
+  std::shared_ptr<Allocator> normal_allocator_;
 };
 
 // TODO(yy): Dirty code here. This class should be configurable in runtime.
-class ChunkedManagedAllocator : public ManagedAllocator {
+class ChunkedManagedAllocator : public Allocator {
  public:
   explicit ChunkedManagedAllocator(std::unique_ptr<Allocator> system_allocator,
                                    size_t max_chunk_size, size_t capacity = 1,
                                    int64_t retry_time = -1)
      : max_chunk_size_(max_chunk_size), retry_time_(retry_time) {
-    raw_allocator_ = NaiveManagedAllocator::Create(std::move(system_allocator));
+    raw_allocator_ = std::move(system_allocator);
 
     if (max_chunk_size_ == 0) {
       default_allocator_ = raw_allocator_;
@@ -114,11 +107,7 @@ class ChunkedManagedAllocator : public ManagedAllocator {
     return default_allocator_->Allocate(size, attr);
   }
 
-  std::shared_ptr<Allocation> AllocateShared(size_t size, Attr attr) override {
-    return default_allocator_->AllocateShared(size, attr);
-  }
-
-  std::shared_ptr<ManagedAllocator> BestFitAllocatorCreator() {
+  std::shared_ptr<Allocator> BestFitAllocatorCreator() {
     chunks_.emplace_back(raw_allocator_->Allocate(max_chunk_size_));
     auto* allocation = chunks_.back().get();
     std::unique_ptr<Allocator> unmanaged_allocator(new LockedAllocator(
@@ -127,12 +116,13 @@ class ChunkedManagedAllocator : public ManagedAllocator {
     if (retry_time_ <= 0) {
       VLOG(10) << "Create NaiveManagedAllocator without retry";
       return std::make_shared<AlignedAllocator<64u>>(
-          NaiveManagedAllocator::Create(std::move(unmanaged_allocator)));
+          std::move(unmanaged_allocator));
     } else {
       VLOG(10) << "Create RetryAllocator with retry_time " << retry_time_
                << "ms";
-      return std::make_shared<AlignedAllocator<64u>>(RetryAllocator::Create(
-          std::move(unmanaged_allocator), static_cast<size_t>(retry_time_)));
+      auto tmp = std::make_shared<RetryAllocator>(
+          std::move(unmanaged_allocator), static_cast<size_t>(retry_time_));
+      return std::make_shared<AlignedAllocator<64u>>(tmp);
     }
   }
 
@@ -142,8 +132,8 @@ class ChunkedManagedAllocator : public ManagedAllocator {
   size_t max_chunk_size_;
   int64_t retry_time_;
   std::vector<std::unique_ptr<Allocation>> chunks_;
-  std::shared_ptr<ManagedAllocator> raw_allocator_;
-  std::shared_ptr<ManagedAllocator> default_allocator_;
+  std::shared_ptr<Allocator> raw_allocator_;
+  std::shared_ptr<Allocator> default_allocator_;
 };
 
 #ifdef PADDLE_WITH_CUDA
@@ -193,7 +183,7 @@ class CUDAPinnedManagedAllocator : public ChunkedManagedAllocator {
 
 class AllocatorFacadePrivate {
  public:
-  std::map<platform::Place, std::shared_ptr<ManagedAllocator>> allocators_;
+  std::map<platform::Place, std::shared_ptr<Allocator>> allocators_;
 
   ~AllocatorFacadePrivate() = default;
 
@@ -245,7 +235,8 @@ AllocatorFacade& AllocatorFacade::Instance() {
 
 std::shared_ptr<Allocation> AllocatorFacade::AllocShared(
     const platform::Place& place, size_t size, Allocator::Attr attr) {
-  return m_->allocators_.at(place)->AllocateShared(size, attr);
+  return std::shared_ptr<Allocation>(
+      m_->allocators_.at(place)->Allocate(size, attr).release());
 }
 
 std::unique_ptr<Allocation> AllocatorFacade::Alloc(const platform::Place& place,
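
The AllocShared change is the visible payoff: because every Allocation now frees itself on destruction, promoting the unique_ptr to a shared_ptr is sufficient. A hedged caller-side sketch follows; the ShareOneBuffer function, the buffer size, and the allocator_facade.h include path are assumptions, while AllocatorFacade::Instance, AllocShared, and Allocator::kDefault come from the diff.

#include <memory>

#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/platform/place.h"

namespace alloc = paddle::memory::allocation;

// Hypothetical caller: shared ownership of one allocation without any
// AllocateShared method on the allocator side.
void ShareOneBuffer() {
  auto& facade = alloc::AllocatorFacade::Instance();
  std::shared_ptr<alloc::Allocation> buf = facade.AllocShared(
      paddle::platform::CPUPlace(), 1 << 20, alloc::Allocator::kDefault);
  std::shared_ptr<alloc::Allocation> alias = buf;  // second owner
  // The memory is released once both buf and alias are gone, through the
  // Allocation destructor rather than an explicit Free call.
}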

paddle/fluid/memory/allocation/auto_increment_allocator.cc

Lines changed: 50 additions & 9 deletions
@@ -20,20 +20,61 @@ namespace allocation {
 
 std::unique_ptr<Allocation> AutoIncrementAllocator::Allocate(
     size_t size, Allocator::Attr attr) {
-  return InvokeOrCreateUnderlyingAllocator([&](ManagedAllocator& allocator) {
-    return allocator.Allocate(size, attr);
-  });
-}
+  auto cur = prev_success_allocator_.load();
+  size_t retry_count = allocator_num_.load();
+  size_t allocator_num = retry_count;
+  while (retry_count-- > 0) {  // until there retry count is zero
+    try {
+      auto res = underlying_allocators_[cur]->Allocate(size, attr);
+      prev_success_allocator_ = cur;
+      return res;
+    } catch (BadAlloc&) {
+      if (++cur >= allocator_num) {
+        cur = 0;
+      }
+    } catch (...) {
+      // if there is another type of allocation, just rethrow it.
+      throw;
+    }
+  }
 
-std::shared_ptr<Allocation> AutoIncrementAllocator::AllocateShared(
-    size_t size, Allocator::Attr attr) {
-  return InvokeOrCreateUnderlyingAllocator([&](ManagedAllocator& allocator) {
-    return allocator.AllocateShared(size, attr);
-  });
+  // This happens when the first allocator is exhausted and
+  // there are more than 1 allocation requests
+  // In this situation, the first allocation request would success
+  // and the second allocation request would fail if we do not use
+  // the newly created allocator by the first allocation request.
+  for (cur = allocator_num; cur < allocator_num_; ++cur) {
+    try {
+      auto ret = underlying_allocators_[cur]->Allocate(size, attr);
+      prev_success_allocator_ = cur;
+      return ret;
+    } catch (BadAlloc&) {
+    } catch (...) {
+      throw;
+    }
+  }
+  // No suitable allocator
+  return CreateNewAllocator()->Allocate(size, attr);
 }
 
 bool AutoIncrementAllocator::IsAllocThreadSafe() const { return true; }
 
+std::shared_ptr<Allocator> AutoIncrementAllocator::CreateNewAllocator() {
+  std::lock_guard<std::mutex> guard(mtx_);
+  auto old_size = allocator_num_.load();
+  PADDLE_ENFORCE_LT(old_size, underlying_allocators_.size(),
+                    "Allocator number exceeds capacity %d",
+                    underlying_allocators_.size());
+  underlying_allocators_[old_size] = creator_();
+  prev_success_allocator_ = old_size;
+  ++allocator_num_;
+  PADDLE_ENFORCE(
+      underlying_allocators_[old_size]->IsAllocThreadSafe(),
+      "the underlying allocator must be thread safe. This is a program "
+      "bug.");
+  return underlying_allocators_[old_size];
+}
+
 } // namespace allocation
 } // namespace memory
 } // namespace paddle
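
For reference, a sketch of how this allocator is wired up. The creator lambda and the capacity value are invented for illustration; in the commit the real creator is ChunkedManagedAllocator::BestFitAllocatorCreator in allocator_facade.cc, and CreateNewAllocator enforces that whatever the creator returns is thread safe.

#include <memory>

#include "paddle/fluid/memory/allocation/auto_increment_allocator.h"
#include "paddle/fluid/memory/allocation/cpu_allocator.h"

namespace alloc = paddle::memory::allocation;

// Hypothetical set-up: every call to the creator contributes one more
// underlying allocator. Allocate() retries over the existing ones starting
// from the last successful index and only calls CreateNewAllocator() when
// all of them throw BadAlloc.
std::shared_ptr<alloc::Allocator> MakeAutoIncrementAllocator() {
  alloc::AutoIncrementAllocator::AllocatorCreator creator = []() {
    // Assumed thread safe for this sketch; a real creator returns a locked
    // best-fit allocator over a freshly allocated chunk.
    return std::shared_ptr<alloc::Allocator>(new alloc::CPUAllocator());
  };
  // capacity bounds how many underlying allocators can ever be created.
  return std::make_shared<alloc::AutoIncrementAllocator>(std::move(creator),
                                                         /*capacity=*/8);
}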

paddle/fluid/memory/allocation/auto_increment_allocator.h

Lines changed: 5 additions & 61 deletions
@@ -46,76 +46,20 @@ namespace allocation {
 // thread-safe std::vector with varying size is hard to implement.
 // Fortunately, we can get the total GPU memory and each chunk size.
 // Therefore, we can get the suitable capacity of AutoIncrementAllocator.
-class AutoIncrementAllocator : public ManagedAllocator {
+class AutoIncrementAllocator : public Allocator {
  public:
   // Creator is the method to create ManagedAllocator
-  using AllocatorCreator = std::function<std::shared_ptr<ManagedAllocator>()>;
+  using AllocatorCreator = std::function<std::shared_ptr<Allocator>()>;
 
   explicit AutoIncrementAllocator(AllocatorCreator&& creator, size_t capacity)
       : creator_(std::move(creator)), underlying_allocators_(capacity) {}
+
   std::unique_ptr<Allocation> Allocate(size_t size, Attr attr) override;
-  std::shared_ptr<Allocation> AllocateShared(size_t size, Attr attr) override;
+
   bool IsAllocThreadSafe() const override;
 
  private:
-  // NOTE: here use template Callback, it can be inlined when -O3
-  template <typename Callback>
-  inline typename std::result_of<Callback(ManagedAllocator&)>::type
-  InvokeOrCreateUnderlyingAllocator(Callback callback) {
-    auto cur = prev_success_allocator_.load();
-    size_t retry_count = allocator_num_.load();
-    size_t allocator_num = retry_count;
-    while (retry_count-- > 0) {  // until there retry count is zero
-      try {
-        auto res = callback(*underlying_allocators_[cur]);
-        prev_success_allocator_ = cur;
-        return std::move(res);
-      } catch (BadAlloc&) {
-        if (++cur >= allocator_num) {
-          cur = 0;
-        }
-      } catch (...) {
-        // if there is another type of allocation, just rethrow it.
-        throw;
-      }
-    }
-
-    // This happens when the first allocator is exhausted and
-    // there are more than 1 allocation requests
-    // In this situation, the first allocation request would success
-    // and the second allocation request would fail if we do not use
-    // the newly created allocator by the first allocation request.
-    for (cur = allocator_num; cur < allocator_num_; ++cur) {
-      try {
-        auto ret = callback(*underlying_allocators_[cur]);
-        prev_success_allocator_ = cur;
-        return std::move(ret);
-      } catch (BadAlloc&) {
-      } catch (...) {
-        throw;
-      }
-    }
-    // No suitable allocator
-
-    ManagedAllocator* new_allocator;
-    {
-      std::lock_guard<std::mutex> guard(mtx_);
-      auto old_size = allocator_num_.load();
-      PADDLE_ENFORCE_LT(old_size, underlying_allocators_.size(),
-                        "Allocator number exceeds capacity %d",
-                        underlying_allocators_.size());
-      underlying_allocators_[old_size] = creator_();
-      new_allocator = underlying_allocators_[old_size].get();
-      prev_success_allocator_ = old_size;
-      ++allocator_num_;
-    }
-
-    PADDLE_ENFORCE(
-        new_allocator->IsAllocThreadSafe(),
-        "the underlying allocator must be thread safe. This is a program "
-        "bug.");
-    return callback(*new_allocator);
-  }
+  std::shared_ptr<Allocator> CreateNewAllocator();
 
   AllocatorCreator creator_;
 