|
11 | 11 | #include <algorithm>
|
12 | 12 | #include <climits>
|
13 | 13 | #include <string.h>
|
| 14 | +#include <ur/ur.hpp> |
14 | 15 |
|
15 | 16 | #include "context.hpp"
|
16 | 17 | #include "event.hpp"
|
@@ -183,9 +184,6 @@ static ur_result_t enqueueMemFillHelper(ur_command_t CommandType,
|
183 | 184 | uint32_t NumEventsInWaitList,
|
184 | 185 | const ur_event_handle_t *EventWaitList,
|
185 | 186 | ur_event_handle_t *OutEvent) {
|
186 |
| - // Pattern size must be a power of two. |
187 |
| - UR_ASSERT((PatternSize > 0) && ((PatternSize & (PatternSize - 1)) == 0), |
188 |
| - UR_RESULT_ERROR_INVALID_VALUE); |
189 | 187 | auto &Device = Queue->Device;
|
190 | 188 |
|
191 | 189 | // Make sure that pattern size matches the capability of the copy queues.
|
@@ -237,18 +235,42 @@ static ur_result_t enqueueMemFillHelper(ur_command_t CommandType,
|
237 | 235 | const auto &ZeCommandList = CommandList->first;
|
238 | 236 | const auto &WaitList = (*Event)->WaitList;
|
239 | 237 |
|
240 |
| - ZE2UR_CALL(zeCommandListAppendMemoryFill, |
241 |
| - (ZeCommandList, Ptr, Pattern, PatternSize, Size, ZeEvent, |
242 |
| - WaitList.Length, WaitList.ZeEventList)); |
| 238 | + // PatternSize must be a power of two for zeCommandListAppendMemoryFill. |
| 239 | + // When it's not, the fill is emulated with zeCommandListAppendMemoryCopy. |
| 240 | + if (isPowerOf2(PatternSize)) { |
| 241 | + ZE2UR_CALL(zeCommandListAppendMemoryFill, |
| 242 | + (ZeCommandList, Ptr, Pattern, PatternSize, Size, ZeEvent, |
| 243 | + WaitList.Length, WaitList.ZeEventList)); |
243 | 244 |
|
244 |
| - logger::debug("calling zeCommandListAppendMemoryFill() with" |
245 |
| - " ZeEvent {}", |
246 |
| - ur_cast<uint64_t>(ZeEvent)); |
247 |
| - printZeEventList(WaitList); |
| 245 | + logger::debug("calling zeCommandListAppendMemoryFill() with" |
| 246 | + " ZeEvent {}", |
| 247 | + ur_cast<uint64_t>(ZeEvent)); |
| 248 | + printZeEventList(WaitList); |
248 | 249 |
|
249 |
| - // Execute command list asynchronously, as the event will be used |
250 |
| - // to track down its completion. |
251 |
| - UR_CALL(Queue->executeCommandList(CommandList, false, OkToBatch)); |
| 250 | + // Execute command list asynchronously, as the event will be used |
| 251 | + // to track down its completion. |
| 252 | + UR_CALL(Queue->executeCommandList(CommandList, false, OkToBatch)); |
| 253 | + } else { |
| 254 | + // Copy pattern into every entry in memory array pointed by Ptr. |
| 255 | + uint32_t NumOfCopySteps = Size / PatternSize; |
| 256 | + const void *Src = Pattern; |
| 257 | + |
| 258 | + for (uint32_t step = 0; step < NumOfCopySteps; ++step) { |
| 259 | + void *Dst = reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(Ptr) + |
| 260 | + step * PatternSize); |
| 261 | + ZE2UR_CALL(zeCommandListAppendMemoryCopy, |
| 262 | + (ZeCommandList, Dst, Src, PatternSize, ZeEvent, |
| 263 | + WaitList.Length, WaitList.ZeEventList)); |
| 264 | + } |
| 265 | + |
| 266 | + logger::debug("calling zeCommandListAppendMemoryCopy() with" |
| 267 | + " ZeEvent {}", |
| 268 | + ur_cast<uint64_t>(ZeEvent)); |
| 269 | + printZeEventList(WaitList); |
| 270 | + |
| 271 | + // Execute command list synchronously. |
| 272 | + UR_CALL(Queue->executeCommandList(CommandList, true, OkToBatch)); |
| 273 | + } |
252 | 274 |
|
253 | 275 | return UR_RESULT_SUCCESS;
|
254 | 276 | }
|
|
0 commit comments