11#pragma once
22
3- #include < c10/core/Allocator.h>
4- #include < c10/core/DeviceType.h>
3+ #include < c10/hip/HIPCachingAllocator.h>
54
65// Use of c10::hip namespace here makes hipification easier, because
76// I don't have to also fix namespaces. Sorry!
@@ -10,22 +9,220 @@ namespace c10::hip {
109// Takes a valid HIPAllocator (of any sort) and turns it into
1110// an allocator pretending to be a CUDA allocator. See
1211// Note [Masquerading as CUDA]
13- class HIPAllocatorMasqueradingAsCUDA final : public Allocator {
14- Allocator * allocator_;
12+ class HIPAllocatorMasqueradingAsCUDA final : public HIPCachingAllocator::HIPAllocator {
13+ HIPCachingAllocator::HIPAllocator * allocator_;
1514public:
16- explicit HIPAllocatorMasqueradingAsCUDA (Allocator * allocator)
15+ explicit HIPAllocatorMasqueradingAsCUDA (HIPCachingAllocator::HIPAllocator * allocator)
1716 : allocator_(allocator) {}
17+
18+ virtual ~HIPAllocatorMasqueradingAsCUDA () = default ;
19+
20+ // From c10::Allocator
21+
1822 DataPtr allocate (size_t size) override {
1923 DataPtr r = allocator_->allocate (size);
2024 r.unsafe_set_device (Device (c10::DeviceType::CUDA, r.device ().index ()));
2125 return r;
2226 }
27+
28+ bool is_simple_data_ptr (const DataPtr& data_ptr) const override {
29+ return allocator_->is_simple_data_ptr (data_ptr);
30+ }
31+
2332 DeleterFnPtr raw_deleter () const override {
2433 return allocator_->raw_deleter ();
2534 }
35+
2636 void copy_data (void * dest, const void * src, std::size_t count) const final {
2737 allocator_->copy_data (dest, src, count);
2838 }
39+
40+ // From CUDAAllocator
41+
42+ void * raw_alloc (size_t nbytes) override {
43+ return allocator_->raw_alloc (nbytes);
44+ }
45+
46+ void * raw_alloc_with_stream (size_t nbytes, hipStream_t stream) override {
47+ return allocator_->raw_alloc_with_stream (nbytes, stream);
48+ }
49+
50+ void raw_delete (void * ptr) override {
51+ allocator_->raw_delete (ptr);
52+ }
53+
54+ void init (int device_count) override {
55+ allocator_->init (device_count);
56+ }
57+
58+ bool initialized () override {
59+ return allocator_->initialized ();
60+ }
61+
62+ double getMemoryFraction (c10::DeviceIndex device) override {
63+ return allocator_->getMemoryFraction (device);
64+ }
65+
66+ void setMemoryFraction (double fraction, c10::DeviceIndex device) override {
67+ allocator_->setMemoryFraction (fraction, device);
68+ }
69+
70+ void emptyCache (MempoolId_t mempool_id = {0 , 0 }) override {
71+ allocator_->emptyCache (mempool_id);
72+ }
73+
74+ void enable (bool value) override {
75+ allocator_->enable (value);
76+ }
77+
78+ bool isEnabled () const override {
79+ return allocator_->isEnabled ();
80+ }
81+
82+ void cacheInfo (c10::DeviceIndex device, size_t * largestBlock) override {
83+ allocator_->cacheInfo (device, largestBlock);
84+ }
85+
86+ void * getBaseAllocation (void * ptr, size_t * size) override {
87+ return allocator_->getBaseAllocation (ptr, size);
88+ }
89+
90+ void recordStream (const DataPtr& ptr, HIPStream stream) override {
91+ allocator_->recordStream (ptr, stream);
92+ }
93+
94+ CachingDeviceAllocator::DeviceStats getDeviceStats (c10::DeviceIndex device) override {
95+ return allocator_->getDeviceStats (device);
96+ }
97+
98+ void resetAccumulatedStats (c10::DeviceIndex device) override {
99+ allocator_->resetAccumulatedStats (device);
100+ }
101+
102+ void resetPeakStats (c10::DeviceIndex device) override {
103+ allocator_->resetPeakStats (device);
104+ }
105+
106+ HIPCachingAllocator::SnapshotInfo snapshot (MempoolId_t mempool_id = {0 , 0 }) override {
107+ return allocator_->snapshot (mempool_id);
108+ }
109+
110+ void beginAllocateToPool (
111+ c10::DeviceIndex device,
112+ MempoolId_t mempool_id,
113+ std::function<bool (hipStream_t)> filter) override {
114+ allocator_->beginAllocateToPool (device, mempool_id, filter);
115+ }
116+
117+ void endAllocateToPool (
118+ c10::DeviceIndex device,
119+ MempoolId_t mempool_id) override {
120+ allocator_->endAllocateToPool (device, mempool_id);
121+ }
122+
123+ void releasePool (c10::DeviceIndex device, MempoolId_t mempool_id) override {
124+ allocator_->releasePool (device, mempool_id);
125+ }
126+
127+ int getPoolUseCount (c10::DeviceIndex device, MempoolId_t mempool_id) override {
128+ return allocator_->getPoolUseCount (device, mempool_id);
129+ }
130+
131+ void createOrIncrefPool (
132+ c10::DeviceIndex device,
133+ MempoolId_t mempool_id,
134+ HIPAllocator* allocator = nullptr ) override {
135+ allocator_->createOrIncrefPool (device, mempool_id, allocator);
136+ }
137+
138+ void setUseOnOOM (c10::DeviceIndex device, MempoolId_t mempool_id) override {
139+ allocator_->setUseOnOOM (device, mempool_id);
140+ }
141+
142+ bool checkPoolLiveAllocations (
143+ c10::DeviceIndex device,
144+ MempoolId_t mempool_id,
145+ const std::unordered_set<void *>& expected_live_allocations) override {
146+ return allocator_->checkPoolLiveAllocations (device, mempool_id, expected_live_allocations);
147+ }
148+
149+ HIPCachingAllocator::ShareableHandle shareIpcHandle (void * ptr) override {
150+ return allocator_->shareIpcHandle (ptr);
151+ }
152+
153+ std::shared_ptr<void > getIpcDevPtr (std::string handle) override {
154+ return allocator_->getIpcDevPtr (handle);
155+ }
156+
157+ bool isHistoryEnabled () override {
158+ return allocator_->isHistoryEnabled ();
159+ }
160+
161+ void recordHistory (
162+ bool enabled,
163+ HIPCachingAllocator::CreateContextFn context_recorder,
164+ size_t alloc_trace_max_entries,
165+ HIPCachingAllocator::RecordContext when,
166+ bool clearHistory) override {
167+ allocator_->recordHistory (enabled, context_recorder, alloc_trace_max_entries, when, clearHistory);
168+ }
169+
170+ void recordAnnotation (
171+ const std::vector<std::pair<std::string, std::string>>& md) override {
172+ allocator_->recordAnnotation (md);
173+ }
174+
175+ void pushCompileContext (std::string& md) override {
176+ allocator_->pushCompileContext (md);
177+ }
178+
179+ void popCompileContext () override {
180+ allocator_->popCompileContext ();
181+ }
182+
183+ void attachOutOfMemoryObserver (HIPCachingAllocator::OutOfMemoryObserver observer) override {
184+ allocator_->attachOutOfMemoryObserver (observer);
185+ }
186+
187+ void attachAllocatorTraceTracker (HIPCachingAllocator::AllocatorTraceTracker tracker) override {
188+ allocator_->attachAllocatorTraceTracker (tracker);
189+ }
190+
191+ void enablePeerAccess (c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) override {
192+ allocator_->enablePeerAccess (dev, dev_to_access);
193+ }
194+
195+ hipError_t memcpyAsync (
196+ void * dst,
197+ int dstDevice,
198+ const void * src,
199+ int srcDevice,
200+ size_t count,
201+ hipStream_t stream,
202+ bool p2p_enabled) override {
203+ return allocator_->memcpyAsync (dst, dstDevice, src, srcDevice, count, stream, p2p_enabled);
204+ }
205+
206+ std::shared_ptr<HIPCachingAllocator::AllocatorState> getCheckpointState (
207+ c10::DeviceIndex device,
208+ MempoolId_t id) override {
209+ return allocator_->getCheckpointState (device, id);
210+ }
211+
212+ HIPCachingAllocator::CheckpointDelta setCheckpointPoolState (
213+ c10::DeviceIndex device,
214+ std::shared_ptr<HIPCachingAllocator::AllocatorState> pps) override {
215+ auto cpd = allocator_->setCheckpointPoolState (device, pps);
216+ for (auto & ptr : cpd.dataptrs_allocd ) {
217+ ptr.unsafe_set_device (Device (c10::DeviceType::CUDA, ptr.device ().index ()));
218+ }
219+ return cpd;
220+ }
221+
222+ std::string name () override {
223+ return allocator_->name ();
224+ }
225+
29226};
30227
31228} // namespace c10::hip
0 commit comments