11#pragma once
22
3- #include < c10/core/Allocator.h>
4- #include < c10/core/DeviceType.h>
3+ #include < c10/hip/HIPCachingAllocator.h>
54
65// Use of c10::hip namespace here makes hipification easier, because
76// I don't have to also fix namespaces. Sorry!
@@ -10,22 +9,206 @@ namespace c10::hip {
109// Takes a valid HIPAllocator (of any sort) and turns it into
1110// an allocator pretending to be a CUDA allocator. See
1211// Note [Masquerading as CUDA]
13- class HIPAllocatorMasqueradingAsCUDA final : public Allocator {
14- Allocator * allocator_;
12+ class HIPAllocatorMasqueradingAsCUDA final : public HIPCachingAllocator::HIPAllocator {
13+ HIPCachingAllocator::HIPAllocator * allocator_;
1514public:
16- explicit HIPAllocatorMasqueradingAsCUDA (Allocator * allocator)
15+ explicit HIPAllocatorMasqueradingAsCUDA (HIPCachingAllocator::HIPAllocator * allocator)
1716 : allocator_(allocator) {}
17+
18+ virtual ~HIPAllocatorMasqueradingAsCUDA () = default ;
19+
20+ // From c10::Allocator
21+
1822 DataPtr allocate (size_t size) override {
1923 DataPtr r = allocator_->allocate (size);
2024 r.unsafe_set_device (Device (c10::DeviceType::CUDA, r.device ().index ()));
2125 return r;
2226 }
27+
28+ bool is_simple_data_ptr (const DataPtr& data_ptr) const override {
29+ return allocator_->is_simple_data_ptr (data_ptr);
30+ }
31+
2332 DeleterFnPtr raw_deleter () const override {
2433 return allocator_->raw_deleter ();
2534 }
35+
2636 void copy_data (void * dest, const void * src, std::size_t count) const final {
2737 allocator_->copy_data (dest, src, count);
2838 }
39+
40+ // From CUDAAllocator
41+
42+ void * raw_alloc (size_t nbytes) override {
43+ return allocator_->raw_alloc (nbytes);
44+ }
45+
46+ void * raw_alloc_with_stream (size_t nbytes, hipStream_t stream) override {
47+ return allocator_->raw_alloc_with_stream (nbytes, stream);
48+ }
49+
50+ void raw_delete (void * ptr) override {
51+ allocator_->raw_delete (ptr);
52+ }
53+
54+ void init (int device_count) override {
55+ allocator_->init (device_count);
56+ }
57+
58+ bool initialized () override {
59+ return allocator_->initialized ();
60+ }
61+
62+ double getMemoryFraction (c10::DeviceIndex device) override {
63+ return allocator_->getMemoryFraction (device);
64+ }
65+
66+ void setMemoryFraction (double fraction, c10::DeviceIndex device) override {
67+ allocator_->setMemoryFraction (fraction, device);
68+ }
69+
70+ void emptyCache () override {
71+ allocator_->emptyCache ();
72+ }
73+
74+ void enable (bool value) override {
75+ allocator_->enable (value);
76+ }
77+
78+ bool isEnabled () const override {
79+ return allocator_->isEnabled ();
80+ }
81+
82+ void cacheInfo (c10::DeviceIndex device, size_t * largestBlock) override {
83+ allocator_->cacheInfo (device, largestBlock);
84+ }
85+
86+ void * getBaseAllocation (void * ptr, size_t * size) override {
87+ return allocator_->getBaseAllocation (ptr, size);
88+ }
89+
90+ void recordStream (const DataPtr& ptr, HIPStream stream) override {
91+ allocator_->recordStream (ptr, stream);
92+ }
93+
94+ CachingDeviceAllocator::DeviceStats getDeviceStats (c10::DeviceIndex device) override {
95+ return allocator_->getDeviceStats (device);
96+ }
97+
98+ void resetAccumulatedStats (c10::DeviceIndex device) override {
99+ allocator_->resetAccumulatedStats (device);
100+ }
101+
102+ void resetPeakStats (c10::DeviceIndex device) override {
103+ allocator_->resetPeakStats (device);
104+ }
105+
106+ HIPCachingAllocator::SnapshotInfo snapshot () override {
107+ return allocator_->snapshot ();
108+ }
109+
110+ void beginAllocateToPool (
111+ c10::DeviceIndex device,
112+ MempoolId_t mempool_id,
113+ std::function<bool (hipStream_t)> filter) override {
114+ allocator_->beginAllocateToPool (device, mempool_id, filter);
115+ }
116+
117+ void endAllocateToPool (
118+ c10::DeviceIndex device,
119+ MempoolId_t mempool_id) override {
120+ allocator_->endAllocateToPool (device, mempool_id);
121+ }
122+
123+ void releasePool (c10::DeviceIndex device, MempoolId_t mempool_id) override {
124+ allocator_->releasePool (device, mempool_id);
125+ }
126+
127+ int getPoolUseCount (c10::DeviceIndex device, MempoolId_t mempool_id) override {
128+ return allocator_->getPoolUseCount (device, mempool_id);
129+ }
130+
131+ void ensureExistsAndIncrefPool (
132+ c10::DeviceIndex device,
133+ MempoolId_t mempool_id) override {
134+ allocator_->ensureExistsAndIncrefPool (device, mempool_id);
135+ }
136+
137+ bool checkPoolLiveAllocations (
138+ c10::DeviceIndex device,
139+ MempoolId_t mempool_id,
140+ const std::unordered_set<void *>& expected_live_allocations) override {
141+ return allocator_->checkPoolLiveAllocations (device, mempool_id, expected_live_allocations);
142+ }
143+
144+ HIPCachingAllocator::ShareableHandle shareIpcHandle (void * ptr) override {
145+ return allocator_->shareIpcHandle (ptr);
146+ }
147+
148+ std::shared_ptr<void > getIpcDevPtr (std::string handle) override {
149+ return allocator_->getIpcDevPtr (handle);
150+ }
151+
152+ bool isHistoryEnabled () override {
153+ return allocator_->isHistoryEnabled ();
154+ }
155+
156+ void recordHistory (
157+ bool enabled,
158+ HIPCachingAllocator::CreateContextFn context_recorder,
159+ size_t alloc_trace_max_entries,
160+ HIPCachingAllocator::RecordContext when) override {
161+ allocator_->recordHistory (enabled, context_recorder, alloc_trace_max_entries, when);
162+ }
163+
164+ void recordAnnotation (
165+ const std::vector<std::pair<std::string, std::string>>& md) override {
166+ allocator_->recordAnnotation (md);
167+ }
168+
169+ void attachOutOfMemoryObserver (HIPCachingAllocator::OutOfMemoryObserver observer) override {
170+ allocator_->attachOutOfMemoryObserver (observer);
171+ }
172+
173+ void attachAllocatorTraceTracker (HIPCachingAllocator::AllocatorTraceTracker tracker) override {
174+ allocator_->attachAllocatorTraceTracker (tracker);
175+ }
176+
177+ void enablePeerAccess (c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) override {
178+ allocator_->enablePeerAccess (dev, dev_to_access);
179+ }
180+
181+ hipError_t memcpyAsync (
182+ void * dst,
183+ int dstDevice,
184+ const void * src,
185+ int srcDevice,
186+ size_t count,
187+ hipStream_t stream,
188+ bool p2p_enabled) override {
189+ return allocator_->memcpyAsync (dst, dstDevice, src, srcDevice, count, stream, p2p_enabled);
190+ }
191+
192+ std::shared_ptr<HIPCachingAllocator::AllocatorState> getCheckpointState (
193+ c10::DeviceIndex device,
194+ MempoolId_t id) override {
195+ return allocator_->getCheckpointState (device, id);
196+ }
197+
198+ HIPCachingAllocator::CheckpointDelta setCheckpointPoolState (
199+ c10::DeviceIndex device,
200+ std::shared_ptr<HIPCachingAllocator::AllocatorState> pps) override {
201+ auto cpd = allocator_->setCheckpointPoolState (device, pps);
202+ for (auto & ptr : cpd.dataptrs_allocd ) {
203+ ptr.unsafe_set_device (Device (c10::DeviceType::CUDA, ptr.device ().index ()));
204+ }
205+ return cpd;
206+ }
207+
208+ std::string name () override {
209+ return allocator_->name ();
210+ }
211+
29212};
30213
31214} // namespace c10::hip
0 commit comments