Skip to content

Commit e3cca5e

Browse files
authored
[release/2.7] revamp HIPCachingAllocatorMasqueradingAsCUDA (#2593)
cherry pick of pytorch#161221
1 parent 85f255b commit e3cca5e

File tree

4 files changed

+650
-8
lines changed

4 files changed

+650
-8
lines changed
Lines changed: 188 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#pragma once
22

3-
#include <c10/core/Allocator.h>
4-
#include <c10/core/DeviceType.h>
3+
#include <c10/hip/HIPCachingAllocator.h>
54

65
// Use of c10::hip namespace here makes hipification easier, because
76
// I don't have to also fix namespaces. Sorry!
@@ -10,22 +9,206 @@ namespace c10::hip {
109
// Takes a valid HIPAllocator (of any sort) and turns it into
1110
// an allocator pretending to be a CUDA allocator. See
1211
// Note [Masquerading as CUDA]
13-
class HIPAllocatorMasqueradingAsCUDA final : public Allocator {
14-
Allocator* allocator_;
12+
class HIPAllocatorMasqueradingAsCUDA final : public HIPCachingAllocator::HIPAllocator {
13+
HIPCachingAllocator::HIPAllocator* allocator_;
1514
public:
16-
explicit HIPAllocatorMasqueradingAsCUDA(Allocator* allocator)
15+
explicit HIPAllocatorMasqueradingAsCUDA(HIPCachingAllocator::HIPAllocator* allocator)
1716
: allocator_(allocator) {}
17+
18+
virtual ~HIPAllocatorMasqueradingAsCUDA() = default;
19+
20+
// From c10::Allocator
21+
1822
DataPtr allocate(size_t size) override {
1923
DataPtr r = allocator_->allocate(size);
2024
r.unsafe_set_device(Device(c10::DeviceType::CUDA, r.device().index()));
2125
return r;
2226
}
27+
28+
bool is_simple_data_ptr(const DataPtr& data_ptr) const override {
29+
return allocator_->is_simple_data_ptr(data_ptr);
30+
}
31+
2332
DeleterFnPtr raw_deleter() const override {
2433
return allocator_->raw_deleter();
2534
}
35+
2636
void copy_data(void* dest, const void* src, std::size_t count) const final {
2737
allocator_->copy_data(dest, src, count);
2838
}
39+
40+
// From CUDAAllocator
41+
42+
void* raw_alloc(size_t nbytes) override {
43+
return allocator_->raw_alloc(nbytes);
44+
}
45+
46+
void* raw_alloc_with_stream(size_t nbytes, hipStream_t stream) override {
47+
return allocator_->raw_alloc_with_stream(nbytes, stream);
48+
}
49+
50+
void raw_delete(void* ptr) override {
51+
allocator_->raw_delete(ptr);
52+
}
53+
54+
void init(int device_count) override {
55+
allocator_->init(device_count);
56+
}
57+
58+
bool initialized() override {
59+
return allocator_->initialized();
60+
}
61+
62+
double getMemoryFraction(c10::DeviceIndex device) override {
63+
return allocator_->getMemoryFraction(device);
64+
}
65+
66+
void setMemoryFraction(double fraction, c10::DeviceIndex device) override {
67+
allocator_->setMemoryFraction(fraction, device);
68+
}
69+
70+
void emptyCache() override {
71+
allocator_->emptyCache();
72+
}
73+
74+
void enable(bool value) override {
75+
allocator_->enable(value);
76+
}
77+
78+
bool isEnabled() const override {
79+
return allocator_->isEnabled();
80+
}
81+
82+
void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) override {
83+
allocator_->cacheInfo(device, largestBlock);
84+
}
85+
86+
void* getBaseAllocation(void* ptr, size_t* size) override {
87+
return allocator_->getBaseAllocation(ptr, size);
88+
}
89+
90+
void recordStream(const DataPtr& ptr, HIPStream stream) override {
91+
allocator_->recordStream(ptr, stream);
92+
}
93+
94+
CachingDeviceAllocator::DeviceStats getDeviceStats(c10::DeviceIndex device) override {
95+
return allocator_->getDeviceStats(device);
96+
}
97+
98+
void resetAccumulatedStats(c10::DeviceIndex device) override {
99+
allocator_->resetAccumulatedStats(device);
100+
}
101+
102+
void resetPeakStats(c10::DeviceIndex device) override {
103+
allocator_->resetPeakStats(device);
104+
}
105+
106+
HIPCachingAllocator::SnapshotInfo snapshot() override {
107+
return allocator_->snapshot();
108+
}
109+
110+
void beginAllocateToPool(
111+
c10::DeviceIndex device,
112+
MempoolId_t mempool_id,
113+
std::function<bool(hipStream_t)> filter) override {
114+
allocator_->beginAllocateToPool(device, mempool_id, filter);
115+
}
116+
117+
void endAllocateToPool(
118+
c10::DeviceIndex device,
119+
MempoolId_t mempool_id) override {
120+
allocator_->endAllocateToPool(device, mempool_id);
121+
}
122+
123+
void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) override {
124+
allocator_->releasePool(device, mempool_id);
125+
}
126+
127+
int getPoolUseCount(c10::DeviceIndex device, MempoolId_t mempool_id) override {
128+
return allocator_->getPoolUseCount(device, mempool_id);
129+
}
130+
131+
void ensureExistsAndIncrefPool(
132+
c10::DeviceIndex device,
133+
MempoolId_t mempool_id) override {
134+
allocator_->ensureExistsAndIncrefPool(device, mempool_id);
135+
}
136+
137+
bool checkPoolLiveAllocations(
138+
c10::DeviceIndex device,
139+
MempoolId_t mempool_id,
140+
const std::unordered_set<void*>& expected_live_allocations) override {
141+
return allocator_->checkPoolLiveAllocations(device, mempool_id, expected_live_allocations);
142+
}
143+
144+
HIPCachingAllocator::ShareableHandle shareIpcHandle(void* ptr) override {
145+
return allocator_->shareIpcHandle(ptr);
146+
}
147+
148+
std::shared_ptr<void> getIpcDevPtr(std::string handle) override {
149+
return allocator_->getIpcDevPtr(handle);
150+
}
151+
152+
bool isHistoryEnabled() override {
153+
return allocator_->isHistoryEnabled();
154+
}
155+
156+
void recordHistory(
157+
bool enabled,
158+
HIPCachingAllocator::CreateContextFn context_recorder,
159+
size_t alloc_trace_max_entries,
160+
HIPCachingAllocator::RecordContext when) override {
161+
allocator_->recordHistory(enabled, context_recorder, alloc_trace_max_entries, when);
162+
}
163+
164+
void recordAnnotation(
165+
const std::vector<std::pair<std::string, std::string>>& md) override {
166+
allocator_->recordAnnotation(md);
167+
}
168+
169+
void attachOutOfMemoryObserver(HIPCachingAllocator::OutOfMemoryObserver observer) override {
170+
allocator_->attachOutOfMemoryObserver(observer);
171+
}
172+
173+
void attachAllocatorTraceTracker(HIPCachingAllocator::AllocatorTraceTracker tracker) override {
174+
allocator_->attachAllocatorTraceTracker(tracker);
175+
}
176+
177+
void enablePeerAccess(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) override {
178+
allocator_->enablePeerAccess(dev, dev_to_access);
179+
}
180+
181+
hipError_t memcpyAsync(
182+
void* dst,
183+
int dstDevice,
184+
const void* src,
185+
int srcDevice,
186+
size_t count,
187+
hipStream_t stream,
188+
bool p2p_enabled) override {
189+
return allocator_->memcpyAsync(dst, dstDevice, src, srcDevice, count, stream, p2p_enabled);
190+
}
191+
192+
std::shared_ptr<HIPCachingAllocator::AllocatorState> getCheckpointState(
193+
c10::DeviceIndex device,
194+
MempoolId_t id) override {
195+
return allocator_->getCheckpointState(device, id);
196+
}
197+
198+
HIPCachingAllocator::CheckpointDelta setCheckpointPoolState(
199+
c10::DeviceIndex device,
200+
std::shared_ptr<HIPCachingAllocator::AllocatorState> pps) override {
201+
auto cpd = allocator_->setCheckpointPoolState(device, pps);
202+
for (auto& ptr : cpd.dataptrs_allocd) {
203+
ptr.unsafe_set_device(Device(c10::DeviceType::CUDA, ptr.device().index()));
204+
}
205+
return cpd;
206+
}
207+
208+
std::string name() override {
209+
return allocator_->name();
210+
}
211+
29212
};
30213

31214
} // namespace c10::hip

aten/src/ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
#include <c10/core/Allocator.h>
1+
#include <c10/hip/HIPCachingAllocator.h>
2+
#include <ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h>
23
#include <ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h>
34

45
namespace c10 { namespace hip {
56
namespace HIPCachingAllocatorMasqueradingAsCUDA {
67

7-
Allocator* get() {
8+
HIPCachingAllocator::HIPAllocator* get() {
89
static HIPAllocatorMasqueradingAsCUDA allocator(HIPCachingAllocator::get());
910
return &allocator;
1011
}

0 commit comments

Comments
 (0)