@@ -70,34 +70,41 @@ struct ur_mem_handle_t_ : _ur_object {
7070 // Keeps device of this memory handle
7171 ur_device_handle_t UrDevice;
7272
73+ // Whether this is an image or buffer
74+ enum mem_type_t { image, buffer };
75+ mem_type_t mem_type;
76+
7377 // Enumerates all possible types of accesses.
7478 enum access_mode_t { unknown, read_write, read_only, write_only };
7579
7680 // Interface of the _ur_mem object
7781
7882 // Get the Level Zero handle of the current memory object
79- virtual ur_result_t getZeHandle (char *&ZeHandle, access_mode_t ,
80- ur_device_handle_t Device,
81- const ur_event_handle_t *phWaitEvents,
82- uint32_t numWaitEvents) = 0 ;
83+ ur_result_t getZeHandle (char *&ZeHandle, access_mode_t ,
84+ ur_device_handle_t Device,
85+ const ur_event_handle_t *phWaitEvents,
86+ uint32_t numWaitEvents);
8387
8488 // Get a pointer to the Level Zero handle of the current memory object
85- virtual ur_result_t getZeHandlePtr (char **&ZeHandlePtr, access_mode_t ,
86- ur_device_handle_t Device,
87- const ur_event_handle_t *phWaitEvents,
88- uint32_t numWaitEvents) = 0 ;
89+ ur_result_t getZeHandlePtr (char **&ZeHandlePtr, access_mode_t ,
90+ ur_device_handle_t Device,
91+ const ur_event_handle_t *phWaitEvents,
92+ uint32_t numWaitEvents);
8993
9094 // Method to get type of the derived object (image or buffer)
91- virtual bool isImage () const = 0;
92-
93- virtual ~ur_mem_handle_t_ () = default ;
95+ bool isImage () const { return mem_type == mem_type_t ::image; }
9496
9597protected:
96- ur_mem_handle_t_ (ur_context_handle_t Context)
97- : UrContext{Context}, UrDevice{nullptr } {}
98+ ur_mem_handle_t_ (mem_type_t type, ur_context_handle_t Context)
99+ : UrContext{Context}, UrDevice{nullptr }, mem_type(type) {}
98100
99- ur_mem_handle_t_ (ur_context_handle_t Context, ur_device_handle_t Device)
100- : UrContext{Context}, UrDevice(Device) {}
101+ ur_mem_handle_t_ (mem_type_t type, ur_context_handle_t Context,
102+ ur_device_handle_t Device)
103+ : UrContext{Context}, UrDevice(Device), mem_type(type) {}
104+
105+ // Since the destructor isn't virtual, callers must destruct it via _ur_buffer
106+ // or _ur_image
107+ ~ur_mem_handle_t_ () {};
101108};
102109
103110struct _ur_buffer final : ur_mem_handle_t_ {
@@ -110,7 +117,7 @@ struct _ur_buffer final : ur_mem_handle_t_ {
110117
111118 // Sub-buffer constructor
112119 _ur_buffer (_ur_buffer *Parent, size_t Origin, size_t Size)
113- : ur_mem_handle_t_(Parent->UrContext), Size(Size),
120+ : ur_mem_handle_t_(mem_type_t ::buffer, Parent->UrContext), Size(Size),
114121 SubBuffer{{Parent, Origin}} {
115122 // Retain the Parent Buffer due to the Creation of the SubBuffer.
116123 Parent->RefCount .increment ();
@@ -127,16 +134,15 @@ struct _ur_buffer final : ur_mem_handle_t_ {
127134 // up-to-date and any data copies needed for that are performed under
128135 // the hood.
129136 //
130- virtual ur_result_t getZeHandle (char *&ZeHandle, access_mode_t ,
131- ur_device_handle_t Device,
132- const ur_event_handle_t *phWaitEvents,
133- uint32_t numWaitEvents) override ;
134- virtual ur_result_t getZeHandlePtr (char **&ZeHandlePtr, access_mode_t ,
135- ur_device_handle_t Device,
136- const ur_event_handle_t *phWaitEvents,
137- uint32_t numWaitEvents) override ;
137+ ur_result_t getBufferZeHandle (char *&ZeHandle, access_mode_t ,
138+ ur_device_handle_t Device,
139+ const ur_event_handle_t *phWaitEvents,
140+ uint32_t numWaitEvents);
141+ ur_result_t getBufferZeHandlePtr (char **&ZeHandlePtr, access_mode_t ,
142+ ur_device_handle_t Device,
143+ const ur_event_handle_t *phWaitEvents,
144+ uint32_t numWaitEvents);
138145
139- bool isImage () const override { return false ; }
140146 bool isSubBuffer () const { return SubBuffer != std::nullopt ; }
141147
142148 // Frees all allocations made for the buffer.
@@ -206,35 +212,33 @@ struct _ur_buffer final : ur_mem_handle_t_ {
206212struct _ur_image final : ur_mem_handle_t_ {
207213 // Image constructor
208214 _ur_image (ur_context_handle_t UrContext, ze_image_handle_t ZeImage)
209- : ur_mem_handle_t_(UrContext), ZeImage{ZeImage} {}
215+ : ur_mem_handle_t_(mem_type_t ::image, UrContext), ZeImage{ZeImage} {}
210216
211217 _ur_image (ur_context_handle_t UrContext, ze_image_handle_t ZeImage,
212218 bool OwnZeMemHandle)
213- : ur_mem_handle_t_(UrContext), ZeImage{ZeImage} {
219+ : ur_mem_handle_t_(mem_type_t ::image, UrContext), ZeImage{ZeImage} {
214220 OwnNativeHandle = OwnZeMemHandle;
215221 }
216222
217- virtual ur_result_t getZeHandle (char *&ZeHandle, access_mode_t ,
218- ur_device_handle_t ,
219- const ur_event_handle_t *phWaitEvents,
220- uint32_t numWaitEvents) override {
223+ ur_result_t getImageZeHandle (char *&ZeHandle, access_mode_t ,
224+ ur_device_handle_t ,
225+ const ur_event_handle_t *phWaitEvents,
226+ uint32_t numWaitEvents) {
221227 std::ignore = phWaitEvents;
222228 std::ignore = numWaitEvents;
223229 ZeHandle = reinterpret_cast <char *>(ZeImage);
224230 return UR_RESULT_SUCCESS;
225231 }
226- virtual ur_result_t getZeHandlePtr (char **&ZeHandlePtr, access_mode_t ,
227- ur_device_handle_t ,
228- const ur_event_handle_t *phWaitEvents,
229- uint32_t numWaitEvents) override {
232+ ur_result_t getImageZeHandlePtr (char **&ZeHandlePtr, access_mode_t ,
233+ ur_device_handle_t ,
234+ const ur_event_handle_t *phWaitEvents,
235+ uint32_t numWaitEvents) {
230236 std::ignore = phWaitEvents;
231237 std::ignore = numWaitEvents;
232238 ZeHandlePtr = reinterpret_cast <char **>(&ZeImage);
233239 return UR_RESULT_SUCCESS;
234240 }
235241
236- bool isImage () const override { return true ; }
237-
238242 // Keep the descriptor of the image
239243 ZeStruct<ze_image_desc_t > ZeImageDesc;
240244
0 commit comments