Skip to content

Commit 12b6b94

Browse files
committed
[ROCm] revamp HIPCachingAllocatorMasqueradingAsCUDA (pytorch#161221)
HIPAllocatorMasqueradingAsCUDA and HIPCachingAllocatorMasqueradingAsCUDA are now proper complete wrappers of HIPAllocator and HIPCachingAllocator, respectively. HIPAllocatorMasqueradingAsCUDA now subclasses HIPAllocator instead of Allocator. This fixes usability of hipify replacing c10::cuda::CUDACachingAllocator::get() where callers expect a CUDAAllocator to be returned but instead were getting a very thin Allocator shim instead. This also fixes using cudagraph trees with torch compile. The hip:0 device was not being replaced by the cuda:0 device in all methods. Pull Request resolved: pytorch#161221 Approved by: https://github.com/jeffdaily Co-authored-by: Jeff Daily <[email protected]>
1 parent 9a46fc8 commit 12b6b94

File tree

4 files changed

+685
-8
lines changed

4 files changed

+685
-8
lines changed
Lines changed: 209 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#pragma once
22

3-
#include <c10/core/Allocator.h>
4-
#include <c10/core/DeviceType.h>
3+
#include <c10/hip/HIPCachingAllocator.h>
54

65
// Use of c10::hip namespace here makes hipification easier, because
76
// I don't have to also fix namespaces. Sorry!
@@ -10,22 +9,227 @@ namespace c10::hip {
109
// Takes a valid HIPAllocator (of any sort) and turns it into
1110
// an allocator pretending to be a CUDA allocator. See
1211
// Note [Masquerading as CUDA]
13-
class HIPAllocatorMasqueradingAsCUDA final : public Allocator {
14-
Allocator* allocator_;
12+
class HIPAllocatorMasqueradingAsCUDA final : public HIPCachingAllocator::HIPAllocator {
13+
HIPCachingAllocator::HIPAllocator* allocator_;
1514
public:
16-
explicit HIPAllocatorMasqueradingAsCUDA(Allocator* allocator)
15+
explicit HIPAllocatorMasqueradingAsCUDA(HIPCachingAllocator::HIPAllocator* allocator)
1716
: allocator_(allocator) {}
17+
18+
virtual ~HIPAllocatorMasqueradingAsCUDA() = default;
19+
20+
// From c10::Allocator
21+
1822
DataPtr allocate(size_t size) override {
1923
DataPtr r = allocator_->allocate(size);
2024
r.unsafe_set_device(Device(c10::DeviceType::CUDA, r.device().index()));
2125
return r;
2226
}
27+
28+
bool is_simple_data_ptr(const DataPtr& data_ptr) const override {
29+
return allocator_->is_simple_data_ptr(data_ptr);
30+
}
31+
2332
DeleterFnPtr raw_deleter() const override {
2433
return allocator_->raw_deleter();
2534
}
35+
2636
void copy_data(void* dest, const void* src, std::size_t count) const final {
2737
allocator_->copy_data(dest, src, count);
2838
}
39+
40+
// From DeviceAllocator
41+
42+
bool initialized() override {
43+
return allocator_->initialized();
44+
}
45+
46+
void emptyCache(MempoolId_t mempool_id = {0, 0}) override {
47+
allocator_->emptyCache(mempool_id);
48+
}
49+
50+
void recordStream(const DataPtr& ptr, c10::Stream stream) override {
51+
HIPStream hip_stream = HIPStream(stream);
52+
recordStream(ptr, hip_stream);
53+
}
54+
55+
CachingDeviceAllocator::DeviceStats getDeviceStats(c10::DeviceIndex device) override {
56+
return allocator_->getDeviceStats(device);
57+
}
58+
59+
void resetAccumulatedStats(c10::DeviceIndex device) override {
60+
allocator_->resetAccumulatedStats(device);
61+
}
62+
63+
void resetPeakStats(c10::DeviceIndex device) override {
64+
allocator_->resetPeakStats(device);
65+
}
66+
67+
// From CUDAAllocator
68+
69+
void* raw_alloc(size_t nbytes) override {
70+
return allocator_->raw_alloc(nbytes);
71+
}
72+
73+
void* raw_alloc_with_stream(size_t nbytes, hipStream_t stream) override {
74+
return allocator_->raw_alloc_with_stream(nbytes, stream);
75+
}
76+
77+
void raw_delete(void* ptr) override {
78+
allocator_->raw_delete(ptr);
79+
}
80+
81+
void init(int device_count) override {
82+
allocator_->init(device_count);
83+
}
84+
85+
double getMemoryFraction(c10::DeviceIndex device) override {
86+
return allocator_->getMemoryFraction(device);
87+
}
88+
89+
void setMemoryFraction(double fraction, c10::DeviceIndex device) override {
90+
allocator_->setMemoryFraction(fraction, device);
91+
}
92+
93+
void enable(bool value) override {
94+
allocator_->enable(value);
95+
}
96+
97+
bool isEnabled() const override {
98+
return allocator_->isEnabled();
99+
}
100+
101+
void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) override {
102+
allocator_->cacheInfo(device, largestBlock);
103+
}
104+
105+
void* getBaseAllocation(void* ptr, size_t* size) override {
106+
return allocator_->getBaseAllocation(ptr, size);
107+
}
108+
109+
void recordStream(const DataPtr& ptr, HIPStream stream) override {
110+
allocator_->recordStream(ptr, stream);
111+
}
112+
113+
HIPCachingAllocator::SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0}) override {
114+
return allocator_->snapshot(mempool_id);
115+
}
116+
117+
void beginAllocateToPool(
118+
c10::DeviceIndex device,
119+
MempoolId_t mempool_id,
120+
std::function<bool(hipStream_t)> filter) override {
121+
allocator_->beginAllocateToPool(device, mempool_id, filter);
122+
}
123+
124+
void endAllocateToPool(
125+
c10::DeviceIndex device,
126+
MempoolId_t mempool_id) override {
127+
allocator_->endAllocateToPool(device, mempool_id);
128+
}
129+
130+
void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) override {
131+
allocator_->releasePool(device, mempool_id);
132+
}
133+
134+
int getPoolUseCount(c10::DeviceIndex device, MempoolId_t mempool_id) override {
135+
return allocator_->getPoolUseCount(device, mempool_id);
136+
}
137+
138+
void createOrIncrefPool(
139+
c10::DeviceIndex device,
140+
MempoolId_t mempool_id,
141+
HIPAllocator* allocator = nullptr) override {
142+
allocator_->createOrIncrefPool(device, mempool_id, allocator);
143+
}
144+
145+
void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) override {
146+
allocator_->setUseOnOOM(device, mempool_id);
147+
}
148+
149+
bool checkPoolLiveAllocations(
150+
c10::DeviceIndex device,
151+
MempoolId_t mempool_id,
152+
const std::unordered_set<void*>& expected_live_allocations) override {
153+
return allocator_->checkPoolLiveAllocations(device, mempool_id, expected_live_allocations);
154+
}
155+
156+
HIPCachingAllocator::ShareableHandle shareIpcHandle(void* ptr) override {
157+
return allocator_->shareIpcHandle(ptr);
158+
}
159+
160+
std::shared_ptr<void> getIpcDevPtr(std::string handle) override {
161+
return allocator_->getIpcDevPtr(handle);
162+
}
163+
164+
bool isHistoryEnabled() override {
165+
return allocator_->isHistoryEnabled();
166+
}
167+
168+
void recordHistory(
169+
bool enabled,
170+
HIPCachingAllocator::CreateContextFn context_recorder,
171+
size_t alloc_trace_max_entries,
172+
HIPCachingAllocator::RecordContext when,
173+
bool clearHistory) override {
174+
allocator_->recordHistory(enabled, context_recorder, alloc_trace_max_entries, when, clearHistory);
175+
}
176+
177+
void recordAnnotation(
178+
const std::vector<std::pair<std::string, std::string>>& md) override {
179+
allocator_->recordAnnotation(md);
180+
}
181+
182+
void pushCompileContext(std::string& md) override {
183+
allocator_->pushCompileContext(md);
184+
}
185+
186+
void popCompileContext() override {
187+
allocator_->popCompileContext();
188+
}
189+
190+
void attachOutOfMemoryObserver(HIPCachingAllocator::OutOfMemoryObserver observer) override {
191+
allocator_->attachOutOfMemoryObserver(observer);
192+
}
193+
194+
void attachAllocatorTraceTracker(HIPCachingAllocator::AllocatorTraceTracker tracker) override {
195+
allocator_->attachAllocatorTraceTracker(tracker);
196+
}
197+
198+
void enablePeerAccess(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) override {
199+
allocator_->enablePeerAccess(dev, dev_to_access);
200+
}
201+
202+
hipError_t memcpyAsync(
203+
void* dst,
204+
int dstDevice,
205+
const void* src,
206+
int srcDevice,
207+
size_t count,
208+
hipStream_t stream,
209+
bool p2p_enabled) override {
210+
return allocator_->memcpyAsync(dst, dstDevice, src, srcDevice, count, stream, p2p_enabled);
211+
}
212+
213+
std::shared_ptr<HIPCachingAllocator::AllocatorState> getCheckpointState(
214+
c10::DeviceIndex device,
215+
MempoolId_t id) override {
216+
return allocator_->getCheckpointState(device, id);
217+
}
218+
219+
HIPCachingAllocator::CheckpointDelta setCheckpointPoolState(
220+
c10::DeviceIndex device,
221+
std::shared_ptr<HIPCachingAllocator::AllocatorState> pps) override {
222+
auto cpd = allocator_->setCheckpointPoolState(device, pps);
223+
for (auto& ptr : cpd.dataptrs_allocd) {
224+
ptr.unsafe_set_device(Device(c10::DeviceType::CUDA, ptr.device().index()));
225+
}
226+
return cpd;
227+
}
228+
229+
std::string name() override {
230+
return allocator_->name();
231+
}
232+
29233
};
30234

31235
} // namespace c10::hip

aten/src/ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
#include <c10/core/Allocator.h>
1+
#include <c10/hip/HIPCachingAllocator.h>
2+
#include <ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h>
23
#include <ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h>
34

45
namespace c10 { namespace hip {
56
namespace HIPCachingAllocatorMasqueradingAsCUDA {
67

7-
Allocator* get() {
8+
HIPCachingAllocator::HIPAllocator* get() {
89
static HIPAllocatorMasqueradingAsCUDA allocator(HIPCachingAllocator::get());
910
return &allocator;
1011
}

0 commit comments

Comments
 (0)