@@ -25,6 +25,77 @@ cl_map_flags convertURMapFlagsToCL(ur_map_flags_t URFlags) {
2525 return CLFlags;
2626}
2727
28+ ur_result_t ValidateBufferSize (ur_mem_handle_t Buffer, size_t Size,
29+ size_t Origin) {
30+ size_t BufferSize = 0 ;
31+ CL_RETURN_ON_FAILURE (clGetMemObjectInfo (cl_adapter::cast<cl_mem>(Buffer),
32+ CL_MEM_SIZE, sizeof (BufferSize),
33+ &BufferSize, nullptr ));
34+ if (Size + Origin > BufferSize)
35+ return UR_RESULT_ERROR_INVALID_SIZE;
36+ return UR_RESULT_SUCCESS;
37+ }
38+
39+ ur_result_t ValidateBufferRectSize (ur_mem_handle_t Buffer,
40+ ur_rect_region_t Region,
41+ ur_rect_offset_t Offset) {
42+ size_t BufferSize = 0 ;
43+ CL_RETURN_ON_FAILURE (clGetMemObjectInfo (cl_adapter::cast<cl_mem>(Buffer),
44+ CL_MEM_SIZE, sizeof (BufferSize),
45+ &BufferSize, nullptr ));
46+ if (Offset.x >= BufferSize || Offset.y >= BufferSize ||
47+ Offset.z >= BufferSize) {
48+ return UR_RESULT_ERROR_INVALID_SIZE;
49+ }
50+
51+ if ((Region.width + Offset.x ) * (Region.height + Offset.y ) *
52+ (Region.depth + Offset.z ) >
53+ BufferSize) {
54+ return UR_RESULT_ERROR_INVALID_SIZE;
55+ }
56+
57+ return UR_RESULT_SUCCESS;
58+ }
59+
60+ ur_result_t ValidateImageSize (ur_mem_handle_t Image, ur_rect_region_t Region,
61+ ur_rect_offset_t Origin) {
62+ size_t Width = 0 ;
63+ CL_RETURN_ON_FAILURE (clGetImageInfo (cl_adapter::cast<cl_mem>(Image),
64+ CL_IMAGE_WIDTH, sizeof (Width), &Width,
65+ nullptr ));
66+ if (Region.width + Origin.x > Width) {
67+ return UR_RESULT_ERROR_INVALID_SIZE;
68+ }
69+
70+ size_t Height = 0 ;
71+ CL_RETURN_ON_FAILURE (clGetImageInfo (cl_adapter::cast<cl_mem>(Image),
72+ CL_IMAGE_HEIGHT, sizeof (Height), &Height,
73+ nullptr ));
74+
75+ // CL returns a height and depth of 0 for images that don't have those
76+ // dimensions, but regions for enqueue operations must set these to 1, so we
77+ // need to make this adjustment to validate.
78+ if (Height == 0 )
79+ Height = 1 ;
80+
81+ if (Region.height + Origin.y > Height) {
82+ return UR_RESULT_ERROR_INVALID_SIZE;
83+ }
84+
85+ size_t Depth = 0 ;
86+ CL_RETURN_ON_FAILURE (clGetImageInfo (cl_adapter::cast<cl_mem>(Image),
87+ CL_IMAGE_DEPTH, sizeof (Depth), &Depth,
88+ nullptr ));
89+ if (Depth == 0 )
90+ Depth = 1 ;
91+
92+ if (Region.depth + Origin.z > Depth) {
93+ return UR_RESULT_ERROR_INVALID_SIZE;
94+ }
95+
96+ return UR_RESULT_SUCCESS;
97+ }
98+
2899UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch (
29100 ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
30101 const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
@@ -70,27 +141,33 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
70141 size_t offset, size_t size, void *pDst, uint32_t numEventsInWaitList,
71142 const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
72143
73- CL_RETURN_ON_FAILURE ( clEnqueueReadBuffer (
144+ auto ClErr = clEnqueueReadBuffer (
74145 cl_adapter::cast<cl_command_queue>(hQueue),
75146 cl_adapter::cast<cl_mem>(hBuffer), blockingRead, offset, size, pDst,
76147 numEventsInWaitList, cl_adapter::cast<const cl_event *>(phEventWaitList),
77- cl_adapter::cast<cl_event *>(phEvent))) ;
148+ cl_adapter::cast<cl_event *>(phEvent));
78149
79- return UR_RESULT_SUCCESS;
150+ if (ClErr == CL_INVALID_VALUE) {
151+ UR_RETURN_ON_FAILURE (ValidateBufferSize (hBuffer, size, offset));
152+ }
153+ return mapCLErrorToUR (ClErr);
80154}
81155
82156UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite (
83157 ur_queue_handle_t hQueue, ur_mem_handle_t hBuffer, bool blockingWrite,
84158 size_t offset, size_t size, const void *pSrc, uint32_t numEventsInWaitList,
85159 const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
86160
87- CL_RETURN_ON_FAILURE ( clEnqueueWriteBuffer (
161+ auto ClErr = clEnqueueWriteBuffer (
88162 cl_adapter::cast<cl_command_queue>(hQueue),
89163 cl_adapter::cast<cl_mem>(hBuffer), blockingWrite, offset, size, pSrc,
90164 numEventsInWaitList, cl_adapter::cast<const cl_event *>(phEventWaitList),
91- cl_adapter::cast<cl_event *>(phEvent))) ;
165+ cl_adapter::cast<cl_event *>(phEvent));
92166
93- return UR_RESULT_SUCCESS;
167+ if (ClErr == CL_INVALID_VALUE) {
168+ UR_RETURN_ON_FAILURE (ValidateBufferSize (hBuffer, size, offset));
169+ }
170+ return mapCLErrorToUR (ClErr);
94171}
95172
96173UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferReadRect (
@@ -101,17 +178,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferReadRect(
101178 uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
102179 ur_event_handle_t *phEvent) {
103180
104- CL_RETURN_ON_FAILURE ( clEnqueueReadBufferRect (
181+ auto ClErr = clEnqueueReadBufferRect (
105182 cl_adapter::cast<cl_command_queue>(hQueue),
106183 cl_adapter::cast<cl_mem>(hBuffer), blockingRead,
107184 cl_adapter::cast<const size_t *>(&bufferOrigin),
108185 cl_adapter::cast<const size_t *>(&hostOrigin),
109186 cl_adapter::cast<const size_t *>(®ion), bufferRowPitch,
110187 bufferSlicePitch, hostRowPitch, hostSlicePitch, pDst, numEventsInWaitList,
111188 cl_adapter::cast<const cl_event *>(phEventWaitList),
112- cl_adapter::cast<cl_event *>(phEvent))) ;
189+ cl_adapter::cast<cl_event *>(phEvent));
113190
114- return UR_RESULT_SUCCESS;
191+ if (ClErr == CL_INVALID_VALUE) {
192+ UR_RETURN_ON_FAILURE (ValidateBufferRectSize (hBuffer, region, bufferOrigin));
193+ }
194+ return mapCLErrorToUR (ClErr);
115195}
116196
117197UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect (
@@ -122,17 +202,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
122202 uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
123203 ur_event_handle_t *phEvent) {
124204
125- CL_RETURN_ON_FAILURE ( clEnqueueWriteBufferRect (
205+ auto ClErr = clEnqueueWriteBufferRect (
126206 cl_adapter::cast<cl_command_queue>(hQueue),
127207 cl_adapter::cast<cl_mem>(hBuffer), blockingWrite,
128208 cl_adapter::cast<const size_t *>(&bufferOrigin),
129209 cl_adapter::cast<const size_t *>(&hostOrigin),
130210 cl_adapter::cast<const size_t *>(®ion), bufferRowPitch,
131211 bufferSlicePitch, hostRowPitch, hostSlicePitch, pSrc, numEventsInWaitList,
132212 cl_adapter::cast<const cl_event *>(phEventWaitList),
133- cl_adapter::cast<cl_event *>(phEvent))) ;
213+ cl_adapter::cast<cl_event *>(phEvent));
134214
135- return UR_RESULT_SUCCESS;
215+ if (ClErr == CL_INVALID_VALUE) {
216+ UR_RETURN_ON_FAILURE (ValidateBufferRectSize (hBuffer, region, bufferOrigin));
217+ }
218+ return mapCLErrorToUR (ClErr);
136219}
137220
138221UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy (
@@ -141,14 +224,18 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy(
141224 uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
142225 ur_event_handle_t *phEvent) {
143226
144- CL_RETURN_ON_FAILURE ( clEnqueueCopyBuffer (
227+ auto ClErr = clEnqueueCopyBuffer (
145228 cl_adapter::cast<cl_command_queue>(hQueue),
146229 cl_adapter::cast<cl_mem>(hBufferSrc),
147230 cl_adapter::cast<cl_mem>(hBufferDst), srcOffset, dstOffset, size,
148231 numEventsInWaitList, cl_adapter::cast<const cl_event *>(phEventWaitList),
149- cl_adapter::cast<cl_event *>(phEvent))) ;
232+ cl_adapter::cast<cl_event *>(phEvent));
150233
151- return UR_RESULT_SUCCESS;
234+ if (ClErr == CL_INVALID_VALUE) {
235+ UR_RETURN_ON_FAILURE (ValidateBufferSize (hBufferSrc, size, srcOffset));
236+ UR_RETURN_ON_FAILURE (ValidateBufferSize (hBufferDst, size, dstOffset));
237+ }
238+ return mapCLErrorToUR (ClErr);
152239}
153240
154241UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect (
@@ -159,7 +246,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
159246 uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
160247 ur_event_handle_t *phEvent) {
161248
162- CL_RETURN_ON_FAILURE ( clEnqueueCopyBufferRect (
249+ auto ClErr = clEnqueueCopyBufferRect (
163250 cl_adapter::cast<cl_command_queue>(hQueue),
164251 cl_adapter::cast<cl_mem>(hBufferSrc),
165252 cl_adapter::cast<cl_mem>(hBufferDst),
@@ -168,9 +255,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
168255 cl_adapter::cast<const size_t *>(®ion), srcRowPitch, srcSlicePitch,
169256 dstRowPitch, dstSlicePitch, numEventsInWaitList,
170257 cl_adapter::cast<const cl_event *>(phEventWaitList),
171- cl_adapter::cast<cl_event *>(phEvent))) ;
258+ cl_adapter::cast<cl_event *>(phEvent));
172259
173- return UR_RESULT_SUCCESS;
260+ if (ClErr == CL_INVALID_VALUE) {
261+ UR_RETURN_ON_FAILURE (ValidateBufferRectSize (hBufferSrc, region, srcOrigin));
262+ UR_RETURN_ON_FAILURE (ValidateBufferRectSize (hBufferDst, region, dstOrigin));
263+ }
264+ return mapCLErrorToUR (ClErr);
174265}
175266
176267UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill (
@@ -181,13 +272,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
181272 // CL FillBuffer only allows pattern sizes up to the largest CL type:
182273 // long16/double16
183274 if (patternSize <= 128 ) {
184- CL_RETURN_ON_FAILURE (
185- clEnqueueFillBuffer (cl_adapter::cast<cl_command_queue>(hQueue),
186- cl_adapter::cast<cl_mem>(hBuffer), pPattern,
187- patternSize, offset, size, numEventsInWaitList,
188- cl_adapter::cast<const cl_event *>(phEventWaitList),
189- cl_adapter::cast<cl_event *>(phEvent)));
190- return UR_RESULT_SUCCESS;
275+ auto ClErr = (clEnqueueFillBuffer (
276+ cl_adapter::cast<cl_command_queue>(hQueue),
277+ cl_adapter::cast<cl_mem>(hBuffer), pPattern, patternSize, offset, size,
278+ numEventsInWaitList,
279+ cl_adapter::cast<const cl_event *>(phEventWaitList),
280+ cl_adapter::cast<cl_event *>(phEvent)));
281+ if (ClErr != CL_SUCCESS) {
282+ UR_RETURN_ON_FAILURE (ValidateBufferSize (hBuffer, size, offset));
283+ }
284+ return mapCLErrorToUR (ClErr);
191285 }
192286
193287 auto NumValues = size / sizeof (uint64_t );
@@ -205,6 +299,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
205299 &WriteEvent);
206300 if (ClErr != CL_SUCCESS) {
207301 delete[] HostBuffer;
302+ UR_RETURN_ON_FAILURE (ValidateBufferSize (hBuffer, offset, size));
208303 CL_RETURN_ON_FAILURE (ClErr);
209304 }
210305
@@ -237,15 +332,18 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageRead(
237332 size_t slicePitch, void *pDst, uint32_t numEventsInWaitList,
238333 const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
239334
240- CL_RETURN_ON_FAILURE ( clEnqueueReadImage (
335+ auto ClErr = clEnqueueReadImage (
241336 cl_adapter::cast<cl_command_queue>(hQueue),
242337 cl_adapter::cast<cl_mem>(hImage), blockingRead,
243338 cl_adapter::cast<const size_t *>(&origin),
244339 cl_adapter::cast<const size_t *>(®ion), rowPitch, slicePitch, pDst,
245340 numEventsInWaitList, cl_adapter::cast<const cl_event *>(phEventWaitList),
246- cl_adapter::cast<cl_event *>(phEvent))) ;
341+ cl_adapter::cast<cl_event *>(phEvent));
247342
248- return UR_RESULT_SUCCESS;
343+ if (ClErr == CL_INVALID_VALUE) {
344+ UR_RETURN_ON_FAILURE (ValidateImageSize (hImage, region, origin));
345+ }
346+ return mapCLErrorToUR (ClErr);
249347}
250348
251349UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageWrite (
@@ -254,15 +352,18 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageWrite(
254352 size_t slicePitch, void *pSrc, uint32_t numEventsInWaitList,
255353 const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
256354
257- CL_RETURN_ON_FAILURE ( clEnqueueWriteImage (
355+ auto ClErr = clEnqueueWriteImage (
258356 cl_adapter::cast<cl_command_queue>(hQueue),
259357 cl_adapter::cast<cl_mem>(hImage), blockingWrite,
260358 cl_adapter::cast<const size_t *>(&origin),
261359 cl_adapter::cast<const size_t *>(®ion), rowPitch, slicePitch, pSrc,
262360 numEventsInWaitList, cl_adapter::cast<const cl_event *>(phEventWaitList),
263- cl_adapter::cast<cl_event *>(phEvent))) ;
361+ cl_adapter::cast<cl_event *>(phEvent));
264362
265- return UR_RESULT_SUCCESS;
363+ if (ClErr == CL_INVALID_VALUE) {
364+ UR_RETURN_ON_FAILURE (ValidateImageSize (hImage, region, origin));
365+ }
366+ return mapCLErrorToUR (ClErr);
266367}
267368
268369UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageCopy (
@@ -272,16 +373,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageCopy(
272373 uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
273374 ur_event_handle_t *phEvent) {
274375
275- CL_RETURN_ON_FAILURE ( clEnqueueCopyImage (
376+ auto ClErr = clEnqueueCopyImage (
276377 cl_adapter::cast<cl_command_queue>(hQueue),
277378 cl_adapter::cast<cl_mem>(hImageSrc), cl_adapter::cast<cl_mem>(hImageDst),
278379 cl_adapter::cast<const size_t *>(&srcOrigin),
279380 cl_adapter::cast<const size_t *>(&dstOrigin),
280381 cl_adapter::cast<const size_t *>(®ion), numEventsInWaitList,
281382 cl_adapter::cast<const cl_event *>(phEventWaitList),
282- cl_adapter::cast<cl_event *>(phEvent))) ;
383+ cl_adapter::cast<cl_event *>(phEvent));
283384
284- return UR_RESULT_SUCCESS;
385+ if (ClErr == CL_INVALID_VALUE) {
386+ UR_RETURN_ON_FAILURE (ValidateImageSize (hImageSrc, region, srcOrigin));
387+ UR_RETURN_ON_FAILURE (ValidateImageSize (hImageDst, region, dstOrigin));
388+ }
389+ return mapCLErrorToUR (ClErr);
285390}
286391
287392UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferMap (
@@ -298,9 +403,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferMap(
298403 cl_adapter::cast<const cl_event *>(phEventWaitList),
299404 cl_adapter::cast<cl_event *>(phEvent), &Err);
300405
301- CL_RETURN_ON_FAILURE (Err);
302-
303- return UR_RESULT_SUCCESS;
406+ if (Err == CL_INVALID_VALUE) {
407+ UR_RETURN_ON_FAILURE (ValidateBufferSize (hBuffer, size, offset));
408+ }
409+ return mapCLErrorToUR (Err);
304410}
305411
306412UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemUnmap (
0 commit comments