25
25
#include " Shared/Utils.h"
26
26
#include " omptarget.h"
27
27
28
+ #include " llvm/Support/Error.h"
29
+
30
+ namespace llvm {
31
+
28
32
// / Base class of per-device allocator.
29
33
class DeviceAllocatorTy {
30
34
public:
31
35
virtual ~DeviceAllocatorTy () = default ;
32
36
33
37
// / Allocate a memory of size \p Size . \p HstPtr is used to assist the
34
38
// / allocation.
35
- virtual void *allocate (size_t Size, void *HstPtr,
36
- TargetAllocTy Kind = TARGET_ALLOC_DEFAULT) = 0;
39
+ virtual Expected<void *>
40
+ allocate (size_t Size, void *HstPtr,
41
+ TargetAllocTy Kind = TARGET_ALLOC_DEFAULT) = 0 ;
37
42
38
43
// / Delete the pointer \p TgtPtr on the device
39
- virtual int free (void *TgtPtr, TargetAllocTy Kind = TARGET_ALLOC_DEFAULT) = 0;
44
+ virtual Error free (void *TgtPtr,
45
+ TargetAllocTy Kind = TARGET_ALLOC_DEFAULT) = 0;
40
46
};
41
47
42
48
// / Class of memory manager. The memory manager is per-device by using
@@ -134,17 +140,17 @@ class MemoryManagerTy {
134
140
size_t SizeThreshold = 1U << 13 ;
135
141
136
142
// / Request memory from target device
137
- void *allocateOnDevice (size_t Size, void *HstPtr) const {
143
+ Expected< void *> allocateOnDevice (size_t Size, void *HstPtr) const {
138
144
return DeviceAllocator.allocate (Size, HstPtr, TARGET_ALLOC_DEVICE);
139
145
}
140
146
141
147
// / Deallocate data on device
142
- int deleteOnDevice (void *Ptr) const { return DeviceAllocator.free (Ptr); }
148
+ Error deleteOnDevice (void *Ptr) const { return DeviceAllocator.free (Ptr); }
143
149
144
150
// / This function is called when it tries to allocate memory on device but the
145
151
// / device returns out of memory. It will first free all memory in the
146
152
// / FreeList and try to allocate again.
147
- void *freeAndAllocate (size_t Size, void *HstPtr) {
153
+ Expected< void *> freeAndAllocate (size_t Size, void *HstPtr) {
148
154
std::vector<void *> RemoveList;
149
155
150
156
// Deallocate all memory in FreeList
@@ -154,7 +160,8 @@ class MemoryManagerTy {
154
160
if (List.empty ())
155
161
continue ;
156
162
for (const NodeTy &N : List) {
157
- deleteOnDevice (N.Ptr );
163
+ if (auto Err = deleteOnDevice (N.Ptr ))
164
+ return Err;
158
165
RemoveList.push_back (N.Ptr );
159
166
}
160
167
FreeLists[I].clear ();
@@ -175,14 +182,22 @@ class MemoryManagerTy {
175
182
// / allocate directly on the device. If a \p nullptr is returned, it might
176
183
// / be because the device is OOM. In that case, it will free all unused
177
184
// / memory and then try again.
178
- void *allocateOrFreeAndAllocateOnDevice (size_t Size, void *HstPtr) {
179
- void *TgtPtr = allocateOnDevice (Size, HstPtr);
185
+ Expected<void *> allocateOrFreeAndAllocateOnDevice (size_t Size,
186
+ void *HstPtr) {
187
+ auto TgtPtrOrErr = allocateOnDevice (Size, HstPtr);
188
+ if (!TgtPtrOrErr)
189
+ return TgtPtrOrErr.takeError ();
190
+
191
+ void *TgtPtr = *TgtPtrOrErr;
180
192
// We cannot get memory from the device. It might be due to OOM. Let's
181
193
// free all memory in FreeLists and try again.
182
194
if (TgtPtr == nullptr ) {
183
195
DP (" Failed to get memory on device. Free all memory in FreeLists and "
184
196
" try again.\n " );
185
- TgtPtr = freeAndAllocate (Size, HstPtr);
197
+ TgtPtrOrErr = freeAndAllocate (Size, HstPtr);
198
+ if (!TgtPtrOrErr)
199
+ return TgtPtrOrErr.takeError ();
200
+ TgtPtr = *TgtPtrOrErr;
186
201
}
187
202
188
203
if (TgtPtr == nullptr )
@@ -204,16 +219,17 @@ class MemoryManagerTy {
204
219
205
220
// / Destructor
206
221
~MemoryManagerTy () {
207
- for (auto Itr = PtrToNodeTable.begin (); Itr != PtrToNodeTable.end ();
208
- ++Itr) {
209
- assert (Itr->second .Ptr && " nullptr in map table" );
210
- deleteOnDevice (Itr->second .Ptr );
222
+ for (auto &PtrToNode : PtrToNodeTable) {
223
+ assert (PtrToNode.second .Ptr && " nullptr in map table" );
224
+ if (auto Err = deleteOnDevice (PtrToNode.second .Ptr ))
225
+ REPORT (" Failure to delete memory: %s\n " ,
226
+ toString (std::move (Err)).data ());
211
227
}
212
228
}
213
229
214
230
// / Allocate memory of size \p Size from target device. \p HstPtr is used to
215
231
// / assist the allocation.
216
- void *allocate (size_t Size, void *HstPtr) {
232
+ Expected< void *> allocate (size_t Size, void *HstPtr) {
217
233
// If the size is zero, we will not bother the target device. Just return
218
234
// nullptr directly.
219
235
if (Size == 0 )
@@ -228,11 +244,14 @@ class MemoryManagerTy {
228
244
DP (" %zu is greater than the threshold %zu. Allocate it directly from "
229
245
" device\n " ,
230
246
Size, SizeThreshold);
231
- void *TgtPtr = allocateOrFreeAndAllocateOnDevice (Size, HstPtr);
247
+ auto TgtPtrOrErr = allocateOrFreeAndAllocateOnDevice (Size, HstPtr);
248
+ if (!TgtPtrOrErr)
249
+ return TgtPtrOrErr.takeError ();
232
250
233
- DP (" Got target pointer " DPxMOD " . Return directly.\n " , DPxPTR (TgtPtr));
251
+ DP (" Got target pointer " DPxMOD " . Return directly.\n " ,
252
+ DPxPTR (*TgtPtrOrErr));
234
253
235
- return TgtPtr ;
254
+ return *TgtPtrOrErr ;
236
255
}
237
256
238
257
NodeTy *NodePtr = nullptr ;
@@ -260,8 +279,11 @@ class MemoryManagerTy {
260
279
if (NodePtr == nullptr ) {
261
280
DP (" Cannot find a node in the FreeLists. Allocate on device.\n " );
262
281
// Allocate one on device
263
- void *TgtPtr = allocateOrFreeAndAllocateOnDevice (Size, HstPtr);
282
+ auto TgtPtrOrErr = allocateOrFreeAndAllocateOnDevice (Size, HstPtr);
283
+ if (!TgtPtrOrErr)
284
+ return TgtPtrOrErr.takeError ();
264
285
286
+ void *TgtPtr = *TgtPtrOrErr;
265
287
if (TgtPtr == nullptr )
266
288
return nullptr ;
267
289
@@ -282,7 +304,7 @@ class MemoryManagerTy {
282
304
}
283
305
284
306
// / Deallocate memory pointed by \p TgtPtr
285
- int free (void *TgtPtr) {
307
+ Error free (void *TgtPtr) {
286
308
DP (" MemoryManagerTy::free: target memory " DPxMOD " .\n " , DPxPTR (TgtPtr));
287
309
288
310
NodeTy *P = nullptr ;
@@ -314,7 +336,7 @@ class MemoryManagerTy {
314
336
FreeLists[B].insert (*P);
315
337
}
316
338
317
- return OFFLOAD_SUCCESS ;
339
+ return Error::success () ;
318
340
}
319
341
320
342
// / Get the size threshold from the environment variable
@@ -344,4 +366,6 @@ class MemoryManagerTy {
344
366
constexpr const size_t MemoryManagerTy::BucketSize[];
345
367
constexpr const int MemoryManagerTy::NumBuckets;
346
368
369
+ } // namespace llvm
370
+
347
371
#endif // LLVM_OPENMP_LIBOMPTARGET_PLUGINS_COMMON_MEMORYMANAGER_H
0 commit comments