11#pragma once
22
3- #include < c10/core/Allocator.h>
4- #include < c10/core/DeviceType.h>
3+ #include < c10/hip/HIPCachingAllocator.h>
54
65// Use of c10::hip namespace here makes hipification easier, because
76// I don't have to also fix namespaces. Sorry!
@@ -10,22 +9,227 @@ namespace c10::hip {
109// Takes a valid HIPAllocator (of any sort) and turns it into
1110// an allocator pretending to be a CUDA allocator. See
1211// Note [Masquerading as CUDA]
13- class HIPAllocatorMasqueradingAsCUDA final : public Allocator {
14- Allocator * allocator_;
12+ class HIPAllocatorMasqueradingAsCUDA final : public HIPCachingAllocator::HIPAllocator {
13+ HIPCachingAllocator::HIPAllocator * allocator_;
1514public:
16- explicit HIPAllocatorMasqueradingAsCUDA (Allocator * allocator)
15+ explicit HIPAllocatorMasqueradingAsCUDA (HIPCachingAllocator::HIPAllocator * allocator)
1716 : allocator_(allocator) {}
17+
18+ virtual ~HIPAllocatorMasqueradingAsCUDA () = default ;
19+
20+ // From c10::Allocator
21+
1822 DataPtr allocate (size_t size) override {
1923 DataPtr r = allocator_->allocate (size);
2024 r.unsafe_set_device (Device (c10::DeviceType::CUDA, r.device ().index ()));
2125 return r;
2226 }
27+
28+ bool is_simple_data_ptr (const DataPtr& data_ptr) const override {
29+ return allocator_->is_simple_data_ptr (data_ptr);
30+ }
31+
2332 DeleterFnPtr raw_deleter () const override {
2433 return allocator_->raw_deleter ();
2534 }
35+
2636 void copy_data (void * dest, const void * src, std::size_t count) const final {
2737 allocator_->copy_data (dest, src, count);
2838 }
39+
40+ // From DeviceAllocator
41+
42+ bool initialized () override {
43+ return allocator_->initialized ();
44+ }
45+
46+ void emptyCache (MempoolId_t mempool_id = {0 , 0 }) override {
47+ allocator_->emptyCache (mempool_id);
48+ }
49+
50+ void recordStream (const DataPtr& ptr, c10::Stream stream) override {
51+ HIPStream hip_stream = HIPStream (stream);
52+ recordStream (ptr, hip_stream);
53+ }
54+
55+ CachingDeviceAllocator::DeviceStats getDeviceStats (c10::DeviceIndex device) override {
56+ return allocator_->getDeviceStats (device);
57+ }
58+
59+ void resetAccumulatedStats (c10::DeviceIndex device) override {
60+ allocator_->resetAccumulatedStats (device);
61+ }
62+
63+ void resetPeakStats (c10::DeviceIndex device) override {
64+ allocator_->resetPeakStats (device);
65+ }
66+
67+ // From CUDAAllocator
68+
69+ void * raw_alloc (size_t nbytes) override {
70+ return allocator_->raw_alloc (nbytes);
71+ }
72+
73+ void * raw_alloc_with_stream (size_t nbytes, hipStream_t stream) override {
74+ return allocator_->raw_alloc_with_stream (nbytes, stream);
75+ }
76+
77+ void raw_delete (void * ptr) override {
78+ allocator_->raw_delete (ptr);
79+ }
80+
81+ void init (int device_count) override {
82+ allocator_->init (device_count);
83+ }
84+
85+ double getMemoryFraction (c10::DeviceIndex device) override {
86+ return allocator_->getMemoryFraction (device);
87+ }
88+
89+ void setMemoryFraction (double fraction, c10::DeviceIndex device) override {
90+ allocator_->setMemoryFraction (fraction, device);
91+ }
92+
93+ void enable (bool value) override {
94+ allocator_->enable (value);
95+ }
96+
97+ bool isEnabled () const override {
98+ return allocator_->isEnabled ();
99+ }
100+
101+ void cacheInfo (c10::DeviceIndex device, size_t * largestBlock) override {
102+ allocator_->cacheInfo (device, largestBlock);
103+ }
104+
105+ void * getBaseAllocation (void * ptr, size_t * size) override {
106+ return allocator_->getBaseAllocation (ptr, size);
107+ }
108+
109+ void recordStream (const DataPtr& ptr, HIPStream stream) override {
110+ allocator_->recordStream (ptr, stream);
111+ }
112+
113+ HIPCachingAllocator::SnapshotInfo snapshot (MempoolId_t mempool_id = {0 , 0 }) override {
114+ return allocator_->snapshot (mempool_id);
115+ }
116+
117+ void beginAllocateToPool (
118+ c10::DeviceIndex device,
119+ MempoolId_t mempool_id,
120+ std::function<bool (hipStream_t)> filter) override {
121+ allocator_->beginAllocateToPool (device, mempool_id, filter);
122+ }
123+
124+ void endAllocateToPool (
125+ c10::DeviceIndex device,
126+ MempoolId_t mempool_id) override {
127+ allocator_->endAllocateToPool (device, mempool_id);
128+ }
129+
130+ void releasePool (c10::DeviceIndex device, MempoolId_t mempool_id) override {
131+ allocator_->releasePool (device, mempool_id);
132+ }
133+
134+ int getPoolUseCount (c10::DeviceIndex device, MempoolId_t mempool_id) override {
135+ return allocator_->getPoolUseCount (device, mempool_id);
136+ }
137+
138+ void createOrIncrefPool (
139+ c10::DeviceIndex device,
140+ MempoolId_t mempool_id,
141+ HIPAllocator* allocator = nullptr ) override {
142+ allocator_->createOrIncrefPool (device, mempool_id, allocator);
143+ }
144+
145+ void setUseOnOOM (c10::DeviceIndex device, MempoolId_t mempool_id) override {
146+ allocator_->setUseOnOOM (device, mempool_id);
147+ }
148+
149+ bool checkPoolLiveAllocations (
150+ c10::DeviceIndex device,
151+ MempoolId_t mempool_id,
152+ const std::unordered_set<void *>& expected_live_allocations) override {
153+ return allocator_->checkPoolLiveAllocations (device, mempool_id, expected_live_allocations);
154+ }
155+
156+ HIPCachingAllocator::ShareableHandle shareIpcHandle (void * ptr) override {
157+ return allocator_->shareIpcHandle (ptr);
158+ }
159+
160+ std::shared_ptr<void > getIpcDevPtr (std::string handle) override {
161+ return allocator_->getIpcDevPtr (handle);
162+ }
163+
164+ bool isHistoryEnabled () override {
165+ return allocator_->isHistoryEnabled ();
166+ }
167+
168+ void recordHistory (
169+ bool enabled,
170+ HIPCachingAllocator::CreateContextFn context_recorder,
171+ size_t alloc_trace_max_entries,
172+ HIPCachingAllocator::RecordContext when,
173+ bool clearHistory) override {
174+ allocator_->recordHistory (enabled, context_recorder, alloc_trace_max_entries, when, clearHistory);
175+ }
176+
177+ void recordAnnotation (
178+ const std::vector<std::pair<std::string, std::string>>& md) override {
179+ allocator_->recordAnnotation (md);
180+ }
181+
182+ void pushCompileContext (std::string& md) override {
183+ allocator_->pushCompileContext (md);
184+ }
185+
186+ void popCompileContext () override {
187+ allocator_->popCompileContext ();
188+ }
189+
190+ void attachOutOfMemoryObserver (HIPCachingAllocator::OutOfMemoryObserver observer) override {
191+ allocator_->attachOutOfMemoryObserver (observer);
192+ }
193+
194+ void attachAllocatorTraceTracker (HIPCachingAllocator::AllocatorTraceTracker tracker) override {
195+ allocator_->attachAllocatorTraceTracker (tracker);
196+ }
197+
198+ void enablePeerAccess (c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) override {
199+ allocator_->enablePeerAccess (dev, dev_to_access);
200+ }
201+
202+ hipError_t memcpyAsync (
203+ void * dst,
204+ int dstDevice,
205+ const void * src,
206+ int srcDevice,
207+ size_t count,
208+ hipStream_t stream,
209+ bool p2p_enabled) override {
210+ return allocator_->memcpyAsync (dst, dstDevice, src, srcDevice, count, stream, p2p_enabled);
211+ }
212+
213+ std::shared_ptr<HIPCachingAllocator::AllocatorState> getCheckpointState (
214+ c10::DeviceIndex device,
215+ MempoolId_t id) override {
216+ return allocator_->getCheckpointState (device, id);
217+ }
218+
219+ HIPCachingAllocator::CheckpointDelta setCheckpointPoolState (
220+ c10::DeviceIndex device,
221+ std::shared_ptr<HIPCachingAllocator::AllocatorState> pps) override {
222+ auto cpd = allocator_->setCheckpointPoolState (device, pps);
223+ for (auto & ptr : cpd.dataptrs_allocd ) {
224+ ptr.unsafe_set_device (Device (c10::DeviceType::CUDA, ptr.device ().index ()));
225+ }
226+ return cpd;
227+ }
228+
229+ std::string name () override {
230+ return allocator_->name ();
231+ }
232+
29233};
30234
31235} // namespace c10::hip
0 commit comments