Skip to content

Commit bb06221

Browse files
committed
Add memory handle
1 parent 6b17a96 commit bb06221

File tree

6 files changed

+145
-96
lines changed

6 files changed

+145
-96
lines changed

source/adapters/opencl/command_buffer.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "command_buffer.hpp"
1212
#include "common.hpp"
1313
#include "context.hpp"
14+
#include "memory.hpp"
1415

1516
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferCreateExp(
1617
ur_context_handle_t hContext, ur_device_handle_t hDevice,
@@ -165,8 +166,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyExp(
165166
return UR_RESULT_ERROR_INVALID_OPERATION;
166167

167168
CL_RETURN_ON_FAILURE(clCommandCopyBufferKHR(
168-
hCommandBuffer->CLCommandBuffer, nullptr,
169-
cl_adapter::cast<cl_mem>(hSrcMem), cl_adapter::cast<cl_mem>(hDstMem),
169+
hCommandBuffer->CLCommandBuffer, nullptr, hSrcMem->get(), hDstMem->get(),
170170
srcOffset, dstOffset, size, numSyncPointsInWaitList, pSyncPointWaitList,
171171
pSyncPoint, nullptr));
172172

@@ -202,8 +202,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyRectExp(
202202
return UR_RESULT_ERROR_INVALID_OPERATION;
203203

204204
CL_RETURN_ON_FAILURE(clCommandCopyBufferRectKHR(
205-
hCommandBuffer->CLCommandBuffer, nullptr,
206-
cl_adapter::cast<cl_mem>(hSrcMem), cl_adapter::cast<cl_mem>(hDstMem),
205+
hCommandBuffer->CLCommandBuffer, nullptr, hSrcMem->get(), hDstMem->get(),
207206
OpenCLOriginRect, OpenCLDstRect, OpenCLRegion, srcRowPitch, srcSlicePitch,
208207
dstRowPitch, dstSlicePitch, numSyncPointsInWaitList, pSyncPointWaitList,
209208
pSyncPoint, nullptr));
@@ -291,9 +290,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferFillExp(
291290
return UR_RESULT_ERROR_INVALID_OPERATION;
292291

293292
CL_RETURN_ON_FAILURE(clCommandFillBufferKHR(
294-
hCommandBuffer->CLCommandBuffer, nullptr,
295-
cl_adapter::cast<cl_mem>(hBuffer), pPattern, patternSize, offset, size,
296-
numSyncPointsInWaitList, pSyncPointWaitList, pSyncPoint, nullptr));
293+
hCommandBuffer->CLCommandBuffer, nullptr, hBuffer->get(), pPattern,
294+
patternSize, offset, size, numSyncPointsInWaitList, pSyncPointWaitList,
295+
pSyncPoint, nullptr));
297296

298297
return UR_RESULT_SUCCESS;
299298
}

source/adapters/opencl/context.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,12 @@ urContextGetInfo(ur_context_handle_t hContext, ur_context_info_t propName,
9191
* queries of each device separately and building the intersection set. */
9292
return UR_RESULT_ERROR_INVALID_ARGUMENT;
9393
}
94-
case UR_CONTEXT_INFO_NUM_DEVICES:
95-
case UR_CONTEXT_INFO_DEVICES:
94+
case UR_CONTEXT_INFO_NUM_DEVICES: {
95+
return ReturnValue(hContext->DeviceCount);
96+
}
97+
case UR_CONTEXT_INFO_DEVICES: {
98+
return ReturnValue(hContext->Devices);
99+
}
96100
case UR_CONTEXT_INFO_REFERENCE_COUNT: {
97101
size_t CheckPropSize = 0;
98102
auto ClResult = clGetContextInfo(hContext->get(), CLPropName, propSize,

source/adapters/opencl/enqueue.cpp

Lines changed: 37 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
//===----------------------------------------------------------------------===//
1010

1111
#include "common.hpp"
12+
#include "memory.hpp"
1213
#include "program.hpp"
1314

1415
cl_map_flags convertURMapFlagsToCL(ur_map_flags_t URFlags) {
@@ -72,9 +73,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
7273
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
7374

7475
CL_RETURN_ON_FAILURE(clEnqueueReadBuffer(
75-
cl_adapter::cast<cl_command_queue>(hQueue),
76-
cl_adapter::cast<cl_mem>(hBuffer), blockingRead, offset, size, pDst,
77-
numEventsInWaitList, cl_adapter::cast<const cl_event *>(phEventWaitList),
76+
cl_adapter::cast<cl_command_queue>(hQueue), hBuffer->get(), blockingRead,
77+
offset, size, pDst, numEventsInWaitList,
78+
cl_adapter::cast<const cl_event *>(phEventWaitList),
7879
cl_adapter::cast<cl_event *>(phEvent)));
7980

8081
return UR_RESULT_SUCCESS;
@@ -86,9 +87,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
8687
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
8788

8889
CL_RETURN_ON_FAILURE(clEnqueueWriteBuffer(
89-
cl_adapter::cast<cl_command_queue>(hQueue),
90-
cl_adapter::cast<cl_mem>(hBuffer), blockingWrite, offset, size, pSrc,
91-
numEventsInWaitList, cl_adapter::cast<const cl_event *>(phEventWaitList),
90+
cl_adapter::cast<cl_command_queue>(hQueue), hBuffer->get(), blockingWrite,
91+
offset, size, pSrc, numEventsInWaitList,
92+
cl_adapter::cast<const cl_event *>(phEventWaitList),
9293
cl_adapter::cast<cl_event *>(phEvent)));
9394

9495
return UR_RESULT_SUCCESS;
@@ -107,10 +108,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferReadRect(
107108
const size_t Region[3] = {region.width, region.height, region.depth};
108109

109110
CL_RETURN_ON_FAILURE(clEnqueueReadBufferRect(
110-
cl_adapter::cast<cl_command_queue>(hQueue),
111-
cl_adapter::cast<cl_mem>(hBuffer), blockingRead, BufferOrigin, HostOrigin,
112-
Region, bufferRowPitch, bufferSlicePitch, hostRowPitch, hostSlicePitch,
113-
pDst, numEventsInWaitList,
111+
cl_adapter::cast<cl_command_queue>(hQueue), hBuffer->get(), blockingRead,
112+
BufferOrigin, HostOrigin, Region, bufferRowPitch, bufferSlicePitch,
113+
hostRowPitch, hostSlicePitch, pDst, numEventsInWaitList,
114114
cl_adapter::cast<const cl_event *>(phEventWaitList),
115115
cl_adapter::cast<cl_event *>(phEvent)));
116116

@@ -130,10 +130,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
130130
const size_t Region[3] = {region.width, region.height, region.depth};
131131

132132
CL_RETURN_ON_FAILURE(clEnqueueWriteBufferRect(
133-
cl_adapter::cast<cl_command_queue>(hQueue),
134-
cl_adapter::cast<cl_mem>(hBuffer), blockingWrite, BufferOrigin,
135-
HostOrigin, Region, bufferRowPitch, bufferSlicePitch, hostRowPitch,
136-
hostSlicePitch, pSrc, numEventsInWaitList,
133+
cl_adapter::cast<cl_command_queue>(hQueue), hBuffer->get(), blockingWrite,
134+
BufferOrigin, HostOrigin, Region, bufferRowPitch, bufferSlicePitch,
135+
hostRowPitch, hostSlicePitch, pSrc, numEventsInWaitList,
137136
cl_adapter::cast<const cl_event *>(phEventWaitList),
138137
cl_adapter::cast<cl_event *>(phEvent)));
139138

@@ -147,10 +146,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy(
147146
ur_event_handle_t *phEvent) {
148147

149148
CL_RETURN_ON_FAILURE(clEnqueueCopyBuffer(
150-
cl_adapter::cast<cl_command_queue>(hQueue),
151-
cl_adapter::cast<cl_mem>(hBufferSrc),
152-
cl_adapter::cast<cl_mem>(hBufferDst), srcOffset, dstOffset, size,
153-
numEventsInWaitList, cl_adapter::cast<const cl_event *>(phEventWaitList),
149+
cl_adapter::cast<cl_command_queue>(hQueue), hBufferSrc->get(),
150+
hBufferDst->get(), srcOffset, dstOffset, size, numEventsInWaitList,
151+
cl_adapter::cast<const cl_event *>(phEventWaitList),
154152
cl_adapter::cast<cl_event *>(phEvent)));
155153

156154
return UR_RESULT_SUCCESS;
@@ -168,11 +166,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
168166
const size_t Region[3] = {region.width, region.height, region.depth};
169167

170168
CL_RETURN_ON_FAILURE(clEnqueueCopyBufferRect(
171-
cl_adapter::cast<cl_command_queue>(hQueue),
172-
cl_adapter::cast<cl_mem>(hBufferSrc),
173-
cl_adapter::cast<cl_mem>(hBufferDst), SrcOrigin, DstOrigin, Region,
174-
srcRowPitch, srcSlicePitch, dstRowPitch, dstSlicePitch,
175-
numEventsInWaitList, cl_adapter::cast<const cl_event *>(phEventWaitList),
169+
cl_adapter::cast<cl_command_queue>(hQueue), hBufferSrc->get(),
170+
hBufferDst->get(), SrcOrigin, DstOrigin, Region, srcRowPitch,
171+
srcSlicePitch, dstRowPitch, dstSlicePitch, numEventsInWaitList,
172+
cl_adapter::cast<const cl_event *>(phEventWaitList),
176173
cl_adapter::cast<cl_event *>(phEvent)));
177174

178175
return UR_RESULT_SUCCESS;
@@ -186,12 +183,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
186183
// CL FillBuffer only allows pattern sizes up to the largest CL type:
187184
// long16/double16
188185
if (patternSize <= 128) {
189-
CL_RETURN_ON_FAILURE(
190-
clEnqueueFillBuffer(cl_adapter::cast<cl_command_queue>(hQueue),
191-
cl_adapter::cast<cl_mem>(hBuffer), pPattern,
192-
patternSize, offset, size, numEventsInWaitList,
193-
cl_adapter::cast<const cl_event *>(phEventWaitList),
194-
cl_adapter::cast<cl_event *>(phEvent)));
186+
CL_RETURN_ON_FAILURE(clEnqueueFillBuffer(
187+
cl_adapter::cast<cl_command_queue>(hQueue), hBuffer->get(), pPattern,
188+
patternSize, offset, size, numEventsInWaitList,
189+
cl_adapter::cast<const cl_event *>(phEventWaitList),
190+
cl_adapter::cast<cl_event *>(phEvent)));
195191
return UR_RESULT_SUCCESS;
196192
}
197193

@@ -204,10 +200,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
204200

205201
cl_event WriteEvent = nullptr;
206202
auto ClErr = clEnqueueWriteBuffer(
207-
cl_adapter::cast<cl_command_queue>(hQueue),
208-
cl_adapter::cast<cl_mem>(hBuffer), false, offset, size, HostBuffer,
209-
numEventsInWaitList, cl_adapter::cast<const cl_event *>(phEventWaitList),
210-
&WriteEvent);
203+
cl_adapter::cast<cl_command_queue>(hQueue), hBuffer->get(), false, offset,
204+
size, HostBuffer, numEventsInWaitList,
205+
cl_adapter::cast<const cl_event *>(phEventWaitList), &WriteEvent);
211206
if (ClErr != CL_SUCCESS) {
212207
delete[] HostBuffer;
213208
CL_RETURN_ON_FAILURE(ClErr);
@@ -245,9 +240,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageRead(
245240
const size_t Region[3] = {region.width, region.height, region.depth};
246241

247242
CL_RETURN_ON_FAILURE(clEnqueueReadImage(
248-
cl_adapter::cast<cl_command_queue>(hQueue),
249-
cl_adapter::cast<cl_mem>(hImage), blockingRead, Origin, Region, rowPitch,
250-
slicePitch, pDst, numEventsInWaitList,
243+
cl_adapter::cast<cl_command_queue>(hQueue), hImage->get(), blockingRead,
244+
Origin, Region, rowPitch, slicePitch, pDst, numEventsInWaitList,
251245
cl_adapter::cast<const cl_event *>(phEventWaitList),
252246
cl_adapter::cast<cl_event *>(phEvent)));
253247

@@ -263,9 +257,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageWrite(
263257
const size_t Region[3] = {region.width, region.height, region.depth};
264258

265259
CL_RETURN_ON_FAILURE(clEnqueueWriteImage(
266-
cl_adapter::cast<cl_command_queue>(hQueue),
267-
cl_adapter::cast<cl_mem>(hImage), blockingWrite, Origin, Region, rowPitch,
268-
slicePitch, pSrc, numEventsInWaitList,
260+
cl_adapter::cast<cl_command_queue>(hQueue), hImage->get(), blockingWrite,
261+
Origin, Region, rowPitch, slicePitch, pSrc, numEventsInWaitList,
269262
cl_adapter::cast<const cl_event *>(phEventWaitList),
270263
cl_adapter::cast<cl_event *>(phEvent)));
271264

@@ -283,9 +276,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageCopy(
283276
const size_t Region[3] = {region.width, region.height, region.depth};
284277

285278
CL_RETURN_ON_FAILURE(clEnqueueCopyImage(
286-
cl_adapter::cast<cl_command_queue>(hQueue),
287-
cl_adapter::cast<cl_mem>(hImageSrc), cl_adapter::cast<cl_mem>(hImageDst),
288-
SrcOrigin, DstOrigin, Region, numEventsInWaitList,
279+
cl_adapter::cast<cl_command_queue>(hQueue), hImageSrc->get(),
280+
hImageDst->get(), SrcOrigin, DstOrigin, Region, numEventsInWaitList,
289281
cl_adapter::cast<const cl_event *>(phEventWaitList),
290282
cl_adapter::cast<cl_event *>(phEvent)));
291283

@@ -300,8 +292,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferMap(
300292

301293
cl_int Err;
302294
*ppRetMap = clEnqueueMapBuffer(
303-
cl_adapter::cast<cl_command_queue>(hQueue),
304-
cl_adapter::cast<cl_mem>(hBuffer), blockingMap,
295+
cl_adapter::cast<cl_command_queue>(hQueue), hBuffer->get(), blockingMap,
305296
convertURMapFlagsToCL(mapFlags), offset, size, numEventsInWaitList,
306297
cl_adapter::cast<const cl_event *>(phEventWaitList),
307298
cl_adapter::cast<cl_event *>(phEvent), &Err);
@@ -315,9 +306,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemUnmap(
315306
ur_event_handle_t *phEvent) {
316307

317308
CL_RETURN_ON_FAILURE(clEnqueueUnmapMemObject(
318-
cl_adapter::cast<cl_command_queue>(hQueue),
319-
cl_adapter::cast<cl_mem>(hMem), pMappedPtr, numEventsInWaitList,
320-
cl_adapter::cast<const cl_event *>(phEventWaitList),
309+
cl_adapter::cast<cl_command_queue>(hQueue), hMem->get(), pMappedPtr,
310+
numEventsInWaitList, cl_adapter::cast<const cl_event *>(phEventWaitList),
321311
cl_adapter::cast<cl_event *>(phEvent)));
322312

323313
return UR_RESULT_SUCCESS;

source/adapters/opencl/memory.cpp

Lines changed: 41 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
//
99
//===----------------------------------------------------------------------===//
1010

11+
#include "memory.hpp"
1112
#include "common.hpp"
1213
#include "context.hpp"
1314

@@ -262,18 +263,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferCreate(
262263
}
263264
PropertiesIntel.push_back(0);
264265

265-
*phBuffer = reinterpret_cast<ur_mem_handle_t>(FuncPtr(
266+
cl_mem Buffer = FuncPtr(
266267
CLContext, PropertiesIntel.data(), static_cast<cl_mem_flags>(flags),
267-
size, pProperties->pHost, cl_adapter::cast<cl_int *>(&RetErr)));
268+
size, pProperties->pHost, cl_adapter::cast<cl_int *>(&RetErr));
269+
*phBuffer = new ur_mem_handle_t_(Buffer, hContext);
268270
return mapCLErrorToUR(RetErr);
269271
}
270272
}
271273

272274
void *HostPtr = pProperties ? pProperties->pHost : nullptr;
273-
*phBuffer = reinterpret_cast<ur_mem_handle_t>(
275+
cl_mem Buffer =
274276
clCreateBuffer(hContext->get(), static_cast<cl_mem_flags>(flags), size,
275-
HostPtr, cl_adapter::cast<cl_int *>(&RetErr)));
277+
HostPtr, cl_adapter::cast<cl_int *>(&RetErr));
276278
CL_RETURN_ON_FAILURE(RetErr);
279+
*phBuffer = new ur_mem_handle_t_(Buffer, hContext);
277280

278281
return UR_RESULT_SUCCESS;
279282
}
@@ -289,10 +292,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemImageCreate(
289292
cl_image_desc ImageDesc = mapURImageDescToCL(pImageDesc);
290293
cl_map_flags MapFlags = convertURMemFlagsToCL(flags);
291294

292-
*phMem = reinterpret_cast<ur_mem_handle_t>(
295+
cl_mem Mem =
293296
clCreateImage(hContext->get(), MapFlags, &ImageFormat, &ImageDesc, pHost,
294-
cl_adapter::cast<cl_int *>(&RetErr)));
297+
cl_adapter::cast<cl_int *>(&RetErr));
295298
CL_RETURN_ON_FAILURE(RetErr);
299+
*phMem = new ur_mem_handle_t_(Mem, hContext);
296300

297301
return UR_RESULT_SUCCESS;
298302
}
@@ -318,14 +322,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferPartition(
318322
BufferRegion.size = pRegion->size;
319323

320324
*phMem = reinterpret_cast<ur_mem_handle_t>(clCreateSubBuffer(
321-
cl_adapter::cast<cl_mem>(hBuffer), static_cast<cl_mem_flags>(flags),
322-
BufferCreateType, &BufferRegion, cl_adapter::cast<cl_int *>(&RetErr)));
325+
hBuffer->get(), static_cast<cl_mem_flags>(flags), BufferCreateType,
326+
&BufferRegion, cl_adapter::cast<cl_int *>(&RetErr)));
323327

324328
if (RetErr == CL_INVALID_VALUE) {
325329
size_t BufferSize = 0;
326-
CL_RETURN_ON_FAILURE(clGetMemObjectInfo(cl_adapter::cast<cl_mem>(hBuffer),
327-
CL_MEM_SIZE, sizeof(BufferSize),
328-
&BufferSize, nullptr));
330+
CL_RETURN_ON_FAILURE(clGetMemObjectInfo(
331+
hBuffer->get(), CL_MEM_SIZE, sizeof(BufferSize), &BufferSize, nullptr));
329332
if (BufferRegion.size + BufferRegion.origin > BufferSize)
330333
return UR_RESULT_ERROR_INVALID_BUFFER_SIZE;
331334
}
@@ -334,27 +337,27 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferPartition(
334337

335338
UR_APIEXPORT ur_result_t UR_APICALL
336339
urMemGetNativeHandle(ur_mem_handle_t hMem, ur_native_handle_t *phNativeMem) {
337-
return getNativeHandle(hMem, phNativeMem);
340+
return getNativeHandle(hMem->get(), phNativeMem);
338341
}
339342

340343
UR_APIEXPORT ur_result_t UR_APICALL urMemBufferCreateWithNativeHandle(
341-
ur_native_handle_t hNativeMem,
342-
[[maybe_unused]] ur_context_handle_t hContext,
344+
ur_native_handle_t hNativeMem, ur_context_handle_t hContext,
343345
const ur_mem_native_properties_t *pProperties, ur_mem_handle_t *phMem) {
344-
*phMem = reinterpret_cast<ur_mem_handle_t>(hNativeMem);
346+
cl_mem NativeHandle = reinterpret_cast<cl_mem>(hNativeMem);
347+
*phMem = new ur_mem_handle_t_(NativeHandle, hContext);
345348
if (!pProperties || !pProperties->isNativeHandleOwned) {
346349
return urMemRetain(*phMem);
347350
}
348351
return UR_RESULT_SUCCESS;
349352
}
350353

351354
UR_APIEXPORT ur_result_t UR_APICALL urMemImageCreateWithNativeHandle(
352-
ur_native_handle_t hNativeMem,
353-
[[maybe_unused]] ur_context_handle_t hContext,
355+
ur_native_handle_t hNativeMem, ur_context_handle_t hContext,
354356
[[maybe_unused]] const ur_image_format_t *pImageFormat,
355357
[[maybe_unused]] const ur_image_desc_t *pImageDesc,
356358
const ur_mem_native_properties_t *pProperties, ur_mem_handle_t *phMem) {
357-
*phMem = reinterpret_cast<ur_mem_handle_t>(hNativeMem);
359+
cl_mem NativeHandle = reinterpret_cast<cl_mem>(hNativeMem);
360+
*phMem = new ur_mem_handle_t_(NativeHandle, hContext);
358361
if (!pProperties || !pProperties->isNativeHandleOwned) {
359362
return urMemRetain(*phMem);
360363
}
@@ -370,17 +373,24 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemGetInfo(ur_mem_handle_t hMemory,
370373
UrReturnHelper ReturnValue(propSize, pPropValue, pPropSizeRet);
371374
const cl_int CLPropName = mapURMemInfoToCL(propName);
372375

373-
size_t CheckPropSize = 0;
374-
auto ClResult =
375-
clGetMemObjectInfo(cl_adapter::cast<cl_mem>(hMemory), CLPropName,
376-
propSize, pPropValue, &CheckPropSize);
377-
if (pPropValue && CheckPropSize != propSize) {
378-
return UR_RESULT_ERROR_INVALID_SIZE;
376+
switch (static_cast<uint32_t>(propName)) {
377+
case UR_PROGRAM_INFO_CONTEXT: {
378+
return ReturnValue(hMemory->Context);
379379
}
380-
CL_RETURN_ON_FAILURE(ClResult);
381-
if (pPropSizeRet) {
382-
*pPropSizeRet = CheckPropSize;
380+
default: {
381+
size_t CheckPropSize = 0;
382+
auto ClResult = clGetMemObjectInfo(hMemory->get(), CLPropName, propSize,
383+
pPropValue, &CheckPropSize);
384+
if (pPropValue && CheckPropSize != propSize) {
385+
return UR_RESULT_ERROR_INVALID_SIZE;
386+
}
387+
CL_RETURN_ON_FAILURE(ClResult);
388+
if (pPropSizeRet) {
389+
*pPropSizeRet = CheckPropSize;
390+
}
383391
}
392+
}
393+
384394
return UR_RESULT_SUCCESS;
385395
}
386396

@@ -394,8 +404,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemImageGetInfo(ur_mem_handle_t hMemory,
394404
const cl_int CLPropName = mapURMemImageInfoToCL(propName);
395405

396406
size_t CheckPropSize = 0;
397-
auto ClResult = clGetImageInfo(cl_adapter::cast<cl_mem>(hMemory), CLPropName,
398-
propSize, pPropValue, &CheckPropSize);
407+
auto ClResult = clGetImageInfo(hMemory->get(), CLPropName, propSize,
408+
pPropValue, &CheckPropSize);
399409
if (pPropValue && CheckPropSize != propSize) {
400410
return UR_RESULT_ERROR_INVALID_SIZE;
401411
}
@@ -407,11 +417,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemImageGetInfo(ur_mem_handle_t hMemory,
407417
}
408418

409419
UR_APIEXPORT ur_result_t UR_APICALL urMemRetain(ur_mem_handle_t hMem) {
410-
CL_RETURN_ON_FAILURE(clRetainMemObject(cl_adapter::cast<cl_mem>(hMem)));
420+
CL_RETURN_ON_FAILURE(clRetainMemObject(hMem->get()));
411421
return UR_RESULT_SUCCESS;
412422
}
413423

414424
UR_APIEXPORT ur_result_t UR_APICALL urMemRelease(ur_mem_handle_t hMem) {
415-
CL_RETURN_ON_FAILURE(clReleaseMemObject(cl_adapter::cast<cl_mem>(hMem)));
425+
CL_RETURN_ON_FAILURE(clReleaseMemObject(hMem->get()));
416426
return UR_RESULT_SUCCESS;
417427
}

0 commit comments

Comments
 (0)