99// ===----------------------------------------------------------------------===//
1010#pragma once
1111
12- #include " common.hpp"
1312#include " context.hpp"
1413#include " event.hpp"
1514#include < cassert>
15+ #include < memory>
16+ #include < unordered_map>
1617#include < variant>
1718
19+ #include " common.hpp"
20+
1821ur_result_t allocateMemObjOnDeviceIfNeeded (ur_mem_handle_t ,
1922 const ur_device_handle_t );
2023ur_result_t migrateMemoryToDeviceIfNeeded (ur_mem_handle_t ,
2124 const ur_device_handle_t );
2225
2326// Handler for plain, pointer-based HIP allocations
2427struct BufferMem {
28+ struct BufferMap {
29+ // / Size of the active mapped region.
30+ size_t MapSize;
31+ // / Offset of the active mapped region.
32+ size_t MapOffset;
33+ // / Original flags for the mapped region
34+ ur_map_flags_t MapFlags;
35+ // / Allocated host memory used exclusively for this map.
36+ std::shared_ptr<unsigned char []> MapMem;
37+
38+ BufferMap (size_t MapSize, size_t MapOffset, ur_map_flags_t MapFlags)
39+ : MapSize(MapSize), MapOffset(MapOffset), MapFlags(MapFlags),
40+ MapMem (nullptr ) {}
41+
42+ BufferMap (size_t MapSize, size_t MapOffset, ur_map_flags_t MapFlags,
43+ std::unique_ptr<unsigned char []> &&MapMem)
44+ : MapSize(MapSize), MapOffset(MapOffset), MapFlags(MapFlags),
45+ MapMem(std::move(MapMem)) {}
46+
47+ size_t getMapSize () const noexcept { return MapSize; }
48+
49+ size_t getMapOffset () const noexcept { return MapOffset; }
50+
51+ ur_map_flags_t getMapFlags () const noexcept { return MapFlags; }
52+ };
53+
/** AllocMode
 * Classic: Just a normal buffer allocated on the device via hip malloc
 * UseHostPtr: Use an address on the host for the device
 * CopyIn: The data for the device comes from the host but the host
 *         pointer is not available later for re-use
 * AllocHostPtr: Uses pinned-memory allocation
 */
enum class AllocMode {
  Classic,
  UseHostPtr,
  CopyIn,
  AllocHostPtr
};
62+
2563 using native_type = hipDeviceptr_t;
2664
2765 // If this allocation is a sub-buffer (i.e., a view on an existing
2866 // allocation), this is the pointer to the parent handler structure
2967 ur_mem_handle_t Parent = nullptr ;
3068 // Outer mem holding this struct in variant
3169 ur_mem_handle_t OuterMemStruct;
32-
3370 // / Pointer associated with this device on the host
3471 void *HostPtr;
3572 // / Size of the allocation in bytes
3673 size_t Size;
37- // / Size of the active mapped region.
38- size_t MapSize;
39- // / Offset of the active mapped region.
40- size_t MapOffset;
41- // / Pointer to the active mapped region, if any
42- void *MapPtr;
43- // / Original flags for the mapped region
44- ur_map_flags_t MapFlags;
74+ // / A map that contains all the active mappings for this buffer.
75+ std::unordered_map<void *, BufferMap> PtrToBufferMap;
4576
46- /* * AllocMode
47- * Classic: Just a normal buffer allocated on the device via hip malloc
48- * UseHostPtr: Use an address on the host for the device
49- * CopyIn: The data for the device comes from the host but the host
50- pointer is not available later for re-use
51- * AllocHostPtr: Uses pinned-memory allocation
52- */
53- enum class AllocMode {
54- Classic,
55- UseHostPtr,
56- CopyIn,
57- AllocHostPtr
58- } MemAllocMode;
77+ AllocMode MemAllocMode;
5978
6079private:
6180 // Vector of HIP pointers
@@ -65,10 +84,8 @@ struct BufferMem {
/// Builds the handler for a buffer of \p Size bytes with no active mappings.
/// The native-pointer vector gets one zero-initialised entry per device in
/// \p Context.
BufferMem(ur_context_handle_t Context, ur_mem_handle_t OuterMemStruct,
          AllocMode Mode, void *HostPtr, size_t Size)
    : OuterMemStruct{OuterMemStruct}, HostPtr{HostPtr}, Size{Size},
      PtrToBufferMap{}, MemAllocMode{Mode},
      Ptrs(Context->Devices.size(), native_type{0}) {}
7289
7390 // This will allocate memory on device if there isn't already an active
7491 // allocation on the device
@@ -98,45 +115,41 @@ struct BufferMem {
98115
99116 size_t getSize () const noexcept { return Size; }
100117
101- void *getMapPtr () const noexcept { return MapPtr; }
102-
103- size_t getMapSize () const noexcept { return MapSize; }
104-
105- size_t getMapOffset () const noexcept { return MapOffset; }
118+ BufferMap *getMapDetails (void *Map) {
119+ auto details = PtrToBufferMap.find (Map);
120+ if (details != PtrToBufferMap.end ()) {
121+ return &details->second ;
122+ }
123+ return nullptr ;
124+ }
106125
107126 // / Returns a pointer to data visible on the host that contains
108127 // / the data on the device associated with this allocation.
109128 // / The offset is used to index into the HIP allocation.
110129 // /
111- void *mapToPtr (size_t Size, size_t Offset, ur_map_flags_t Flags) noexcept {
112- assert (MapPtr == nullptr );
113- MapSize = Size;
114- MapOffset = Offset;
115- MapFlags = Flags;
116- if (HostPtr) {
117- MapPtr = static_cast <char *>(HostPtr) + Offset;
130+ void *mapToPtr (size_t MapSize, size_t MapOffset,
131+ ur_map_flags_t MapFlags) noexcept {
132+ void *MapPtr = nullptr ;
133+ if (HostPtr == nullptr ) {
134+ // / If HostPtr is invalid, we need to create a Mapping that owns its own
135+ // / memory on the host.
136+ auto MapMem = std::make_unique<unsigned char []>(MapSize);
137+ MapPtr = MapMem.get ();
138+ PtrToBufferMap.insert (
139+ {MapPtr, BufferMap (MapSize, MapOffset, MapFlags, std::move (MapMem))});
118140 } else {
119- // TODO: Allocate only what is needed based on the offset
120- MapPtr = static_cast <void *>(malloc (this ->getSize ()));
141+ // / However, if HostPtr already has valid memory (e.g. pinned allocation),
142+ // / we can just use that memory for the mapping.
143+ MapPtr = static_cast <char *>(HostPtr) + MapOffset;
144+ PtrToBufferMap.insert ({MapPtr, BufferMap (MapSize, MapOffset, MapFlags)});
121145 }
122146 return MapPtr;
123147 }
124148
125149 // / Detach the allocation from the host memory.
126- void unmap (void *) noexcept {
150+ void unmap (void *MapPtr ) noexcept {
127151 assert (MapPtr != nullptr );
128-
129- if (MapPtr != HostPtr) {
130- free (MapPtr);
131- }
132- MapPtr = nullptr ;
133- MapSize = 0 ;
134- MapOffset = 0 ;
135- }
136-
137- ur_map_flags_t getMapFlags () const noexcept {
138- assert (MapPtr != nullptr );
139- return MapFlags;
152+ PtrToBufferMap.erase (MapPtr);
140153 }
141154
142155 ur_result_t clear () {
@@ -414,7 +427,7 @@ struct ur_mem_handle_t_ {
414427 HaveMigratedToDeviceSinceLastWrite (Context->Devices.size(), false ),
415428 Mem{std::in_place_type<BufferMem>, Ctxt, this , Mode, HostPtr, Size} {
416429 urContextRetain (Context);
417- };
430+ }
418431
419432 // Subbuffer constructor
420433 ur_mem_handle_t_ (ur_mem Parent, size_t SubBufferOffset)
@@ -435,7 +448,7 @@ struct ur_mem_handle_t_ {
435448 }
436449 }
437450 urMemRetain (Parent);
438- };
451+ }
439452
440453 // / Constructs the UR mem handler for an Image object
441454 ur_mem_handle_t_ (ur_context Ctxt, ur_mem_flags_t MemFlags,
0 commit comments