Skip to content

Commit ce492f5

Browse files
committed
[CUDA] Fix potential issue with command buffer fills on CUDA
- Fix a potential issue where decomposed fill nodes for large patterns would overwrite external event dependencies provided by the user when stored in a command handle - Also store decomposed nodes in fill command handles for future use when updating. - Add missing event_sync tests for large pattern fills (> 4 bytes)
1 parent 2739808 commit ce492f5

File tree

4 files changed

+243
-22
lines changed

4 files changed

+243
-22
lines changed

source/adapters/cuda/command_buffer.cpp

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -230,16 +230,29 @@ static ur_result_t enqueueCommandBufferFillHelper(
230230
}
231231

232232
try {
233+
// CUDA has no memset functions that allow setting values more than 4
234+
// bytes. UR API lets you pass an arbitrary "pattern" to the buffer
235+
// fill, which can be more than 4 bytes. Calculate the number of steps
236+
// required here to see if decomposing to multiple fill nodes is required.
237+
size_t NumberOfSteps = PatternSize / sizeof(uint8_t);
238+
233239
// Graph node added to graph, if multiple nodes are created this will
234240
// be set to the leaf node
235241
CUgraphNode GraphNode;
242+
// Track if multiple nodes are created so we can pass them to the command
243+
// handle
244+
std::vector<CUgraphNode> DecomposedNodes;
245+
246+
if (NumberOfSteps > 4) {
247+
DecomposedNodes.reserve(NumberOfSteps);
248+
}
236249

237250
const size_t N = Size / PatternSize;
238251
auto DstPtr = DstType == CU_MEMORYTYPE_DEVICE
239252
? *static_cast<CUdeviceptr *>(DstDevice)
240253
: (CUdeviceptr)DstDevice;
241254

242-
if ((PatternSize == 1) || (PatternSize == 2) || (PatternSize == 4)) {
255+
if (NumberOfSteps <= 4) {
243256
CUDA_MEMSET_NODE_PARAMS NodeParams = {};
244257
NodeParams.dst = DstPtr;
245258
NodeParams.elementSize = PatternSize;
@@ -271,14 +284,9 @@ static ur_result_t enqueueCommandBufferFillHelper(
271284
DepsList.data(), DepsList.size(), &NodeParams,
272285
CommandBuffer->Device->getNativeContext()));
273286
} else {
274-
// CUDA has no memset functions that allow setting values more than 4
275-
// bytes. UR API lets you pass an arbitrary "pattern" to the buffer
276-
// fill, which can be more than 4 bytes. We must break up the pattern
277-
// into 1 byte values, and set the buffer using multiple strided calls.
278-
// This means that one cuGraphAddMemsetNode call is made for every 1
279-
// bytes in the pattern.
280-
281-
size_t NumberOfSteps = PatternSize / sizeof(uint8_t);
287+
// We must break up the rest of the pattern into 1 byte values, and set
288+
// the buffer using multiple strided calls. This means that one
289+
// cuGraphAddMemsetNode call is made for every 1 bytes in the pattern.
282290

283291
// Update NodeParam
284292
CUDA_MEMSET_NODE_PARAMS NodeParamsStepFirst = {};
@@ -289,13 +297,14 @@ static ur_result_t enqueueCommandBufferFillHelper(
289297
NodeParamsStepFirst.value = *static_cast<const uint32_t *>(Pattern);
290298
NodeParamsStepFirst.width = 1;
291299

300+
// Inital decomposed node depends on the provided external event wait
301+
// nodes
292302
UR_CHECK_ERROR(cuGraphAddMemsetNode(
293303
&GraphNode, CommandBuffer->CudaGraph, DepsList.data(),
294304
DepsList.size(), &NodeParamsStepFirst,
295305
CommandBuffer->Device->getNativeContext()));
296306

297-
DepsList.clear();
298-
DepsList.push_back(GraphNode);
307+
DecomposedNodes.push_back(GraphNode);
299308

300309
// we walk up the pattern in 1-byte steps, and call cuMemset for each
301310
// 1-byte chunk of the pattern.
@@ -315,13 +324,16 @@ static ur_result_t enqueueCommandBufferFillHelper(
315324
NodeParamsStep.value = Value;
316325
NodeParamsStep.width = 1;
317326

327+
// Copy the last GraphNode ptr so we can pass it as the dependency for
328+
// the next one
329+
CUgraphNode PrevNode = GraphNode;
330+
318331
UR_CHECK_ERROR(cuGraphAddMemsetNode(
319-
&GraphNode, CommandBuffer->CudaGraph, DepsList.data(),
320-
DepsList.size(), &NodeParamsStep,
332+
&GraphNode, CommandBuffer->CudaGraph, &PrevNode, 1, &NodeParamsStep,
321333
CommandBuffer->Device->getNativeContext()));
322334

323-
DepsList.clear();
324-
DepsList.push_back(GraphNode);
335+
// Store the decomposed node
336+
DecomposedNodes.push_back(GraphNode);
325337
}
326338
}
327339

@@ -340,7 +352,8 @@ static ur_result_t enqueueCommandBufferFillHelper(
340352

341353
std::vector<CUgraphNode> WaitNodes =
342354
NumEventsInWaitList ? DepsList : std::vector<CUgraphNode>();
343-
auto NewCommand = new T(CommandBuffer, GraphNode, SignalNode, WaitNodes);
355+
auto NewCommand = new T(CommandBuffer, GraphNode, SignalNode, WaitNodes,
356+
std::move(DecomposedNodes));
344357
CommandBuffer->CommandHandles.push_back(NewCommand);
345358

346359
if (RetCommand) {

source/adapters/cuda/command_buffer.hpp

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -172,12 +172,18 @@ struct usm_memcpy_command_handle : ur_exp_command_buffer_command_handle_t_ {
172172
struct usm_fill_command_handle : ur_exp_command_buffer_command_handle_t_ {
173173
usm_fill_command_handle(ur_exp_command_buffer_handle_t CommandBuffer,
174174
CUgraphNode Node, CUgraphNode SignalNode,
175-
std::vector<CUgraphNode> WaitNodes)
175+
std::vector<CUgraphNode> WaitNodes,
176+
const std::vector<CUgraphNode> &DecomposedNodes = {})
176177
: ur_exp_command_buffer_command_handle_t_(CommandBuffer, Node, SignalNode,
177-
WaitNodes) {}
178+
WaitNodes),
179+
DecomposedNodes(std::move(DecomposedNodes)) {}
178180
CommandType getCommandType() const noexcept override {
179181
return CommandType::USMFill;
180182
}
183+
184+
// If this fill command was decomposed into multiple nodes, this vector
185+
// contains all of those nodes in the order they were added to the graph.
186+
std::vector<CUgraphNode> DecomposedNodes;
181187
};
182188

183189
struct buffer_copy_command_handle : ur_exp_command_buffer_command_handle_t_ {
@@ -250,14 +256,20 @@ struct buffer_write_rect_command_handle
250256
};
251257

252258
struct buffer_fill_command_handle : ur_exp_command_buffer_command_handle_t_ {
253-
buffer_fill_command_handle(ur_exp_command_buffer_handle_t CommandBuffer,
254-
CUgraphNode Node, CUgraphNode SignalNode,
255-
std::vector<CUgraphNode> WaitNodes)
259+
buffer_fill_command_handle(
260+
ur_exp_command_buffer_handle_t CommandBuffer, CUgraphNode Node,
261+
CUgraphNode SignalNode, std::vector<CUgraphNode> WaitNodes,
262+
const std::vector<CUgraphNode> &DecomposedNodes = {})
256263
: ur_exp_command_buffer_command_handle_t_(CommandBuffer, Node, SignalNode,
257-
WaitNodes) {}
264+
WaitNodes),
265+
DecomposedNodes(std::move(DecomposedNodes)) {}
258266
CommandType getCommandType() const noexcept override {
259267
return CommandType::MemBufferFill;
260268
}
269+
270+
// If this fill command was decomposed into multiple nodes, this vector
271+
// contains all of those nodes in the order they were added to the graph.
272+
std::vector<CUgraphNode> DecomposedNodes;
261273
};
262274

263275
struct usm_prefetch_command_handle : ur_exp_command_buffer_command_handle_t_ {

test/conformance/exp_command_buffer/event_sync.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,42 @@ TEST_P(CommandEventSyncTest, USMFillExp) {
7575
}
7676
}
7777

78+
// Test fill using a large pattern size since implementations may need to handle
79+
// this differently.
80+
TEST_P(CommandEventSyncTest, USMFillLargePatternExp) {
81+
// Device ptrs are allocated in the test fixture with 32-bit values * num
82+
// elements, since we are doubling the pattern size we want to treat those
83+
// device pointers as if they were created with half the number of elements.
84+
constexpr size_t modifiedElementSize = elements / 2;
85+
// Get wait event from queue fill on ptr 0
86+
uint64_t patternX = 42;
87+
ASSERT_SUCCESS(urEnqueueUSMFill(queue, device_ptrs[0], sizeof(patternX),
88+
&patternX, allocation_size, 0, nullptr,
89+
&external_events[0]));
90+
91+
// Test fill command overwriting ptr 0 waiting on queue event
92+
uint64_t patternY = 0xA;
93+
ASSERT_SUCCESS(urCommandBufferAppendUSMFillExp(
94+
cmd_buf_handle, device_ptrs[0], &patternY, sizeof(patternY),
95+
allocation_size, 0, nullptr, 1, &external_events[0], nullptr,
96+
&external_events[1], nullptr));
97+
ASSERT_SUCCESS(urCommandBufferFinalizeExp(cmd_buf_handle));
98+
ASSERT_SUCCESS(
99+
urCommandBufferEnqueueExp(cmd_buf_handle, queue, 0, nullptr, nullptr));
100+
101+
// Queue read ptr 0 based on event returned from command-buffer command
102+
std::array<uint64_t, modifiedElementSize> host_enqueue_ptr{};
103+
ASSERT_SUCCESS(urEnqueueUSMMemcpy(queue, false, host_enqueue_ptr.data(),
104+
device_ptrs[0], allocation_size, 1,
105+
&external_events[1], nullptr));
106+
107+
// Verify
108+
ASSERT_SUCCESS(urQueueFinish(queue));
109+
for (size_t i = 0; i < modifiedElementSize; i++) {
110+
ASSERT_EQ(host_enqueue_ptr[i], patternY);
111+
}
112+
}
113+
78114
TEST_P(CommandEventSyncTest, MemBufferCopyExp) {
79115
// Get wait event from queue fill on buffer 0
80116
uint32_t patternX = 42;
@@ -341,6 +377,42 @@ TEST_P(CommandEventSyncTest, MemBufferFillExp) {
341377
}
342378
}
343379

380+
// Test fill using a large pattern size since implementations may need to handle
381+
// this differently.
382+
TEST_P(CommandEventSyncTest, MemBufferFillLargePatternExp) {
383+
// Device buffers are allocated in the test fixture with 32-bit values * num
384+
// elements, since we are doubling the pattern size we want to treat those
385+
// device pointers as if they were created with half the number of elements.
386+
constexpr size_t modifiedElementSize = elements / 2;
387+
// Get wait event from queue fill on buffer 0
388+
uint64_t patternX = 42;
389+
ASSERT_SUCCESS(urEnqueueMemBufferFill(queue, buffers[0], &patternX,
390+
sizeof(patternX), 0, allocation_size,
391+
0, nullptr, &external_events[0]));
392+
393+
// Test fill command overwriting buffer 0 based on queue event
394+
uint64_t patternY = 0xA;
395+
ASSERT_SUCCESS(urCommandBufferAppendMemBufferFillExp(
396+
cmd_buf_handle, buffers[0], &patternY, sizeof(patternY), 0,
397+
allocation_size, 0, nullptr, 1, &external_events[0], nullptr,
398+
&external_events[1], nullptr));
399+
ASSERT_SUCCESS(urCommandBufferFinalizeExp(cmd_buf_handle));
400+
ASSERT_SUCCESS(
401+
urCommandBufferEnqueueExp(cmd_buf_handle, queue, 0, nullptr, nullptr));
402+
403+
// Queue read buffer 0 based on event returned from command-buffer command
404+
std::array<uint64_t, modifiedElementSize> host_enqueue_ptr{};
405+
ASSERT_SUCCESS(urEnqueueMemBufferRead(
406+
queue, buffers[0], false, 0, allocation_size, host_enqueue_ptr.data(),
407+
1, &external_events[1], nullptr));
408+
409+
// Verify
410+
ASSERT_SUCCESS(urQueueFinish(queue));
411+
for (size_t i = 0; i < modifiedElementSize; i++) {
412+
ASSERT_EQ(host_enqueue_ptr[i], patternY);
413+
}
414+
}
415+
344416
TEST_P(CommandEventSyncTest, USMPrefetchExp) {
345417
// Get wait event from queue fill on ptr 0
346418
uint32_t patternX = 42;

test/conformance/exp_command_buffer/update/event_sync.cpp

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,68 @@ TEST_P(CommandEventSyncUpdateTest, USMFillExp) {
129129
}
130130
}
131131

132+
// Test fill using a large pattern size since implementations may need to handle
133+
// this differently.
134+
TEST_P(CommandEventSyncUpdateTest, USMFillLargePatternExp) {
135+
// Device ptrs are allocated in the test fixture with 32-bit values * num
136+
// elements, since we are doubling the pattern size we want to treat those
137+
// device pointers as if they were created with half the number of elements.
138+
constexpr size_t modifiedElementSize = elements / 2;
139+
// Get wait event from queue fill on ptr 0
140+
uint64_t patternX = 42;
141+
ASSERT_SUCCESS(urEnqueueUSMFill(queue, device_ptrs[0], sizeof(patternX),
142+
&patternX, allocation_size, 0, nullptr,
143+
&external_events[0]));
144+
145+
// Test fill command overwriting ptr 0 waiting on queue event
146+
uint64_t patternY = 0xA;
147+
ASSERT_SUCCESS(urCommandBufferAppendUSMFillExp(
148+
updatable_cmd_buf_handle, device_ptrs[0], &patternY, sizeof(patternY),
149+
allocation_size, 0, nullptr, 1, &external_events[0], nullptr,
150+
&external_events[1], &command_handles[0]));
151+
ASSERT_NE(nullptr, command_handles[0]);
152+
ASSERT_SUCCESS(urCommandBufferFinalizeExp(updatable_cmd_buf_handle));
153+
ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0,
154+
nullptr, nullptr));
155+
156+
// Queue read ptr 0 based on event returned from command-buffer command
157+
std::array<uint64_t, modifiedElementSize> host_enqueue_ptr{};
158+
ASSERT_SUCCESS(urEnqueueUSMMemcpy(queue, false, host_enqueue_ptr.data(),
159+
device_ptrs[0], allocation_size, 1,
160+
&external_events[1], nullptr));
161+
162+
// Verify
163+
ASSERT_SUCCESS(urQueueFinish(queue));
164+
for (size_t i = 0; i < modifiedElementSize; i++) {
165+
ASSERT_EQ(host_enqueue_ptr[i], patternY);
166+
}
167+
168+
uint64_t patternZ = 666;
169+
ASSERT_SUCCESS(urEnqueueUSMFill(queue, device_ptrs[0], sizeof(patternZ),
170+
&patternZ, allocation_size, 0, nullptr,
171+
&external_events[2]));
172+
173+
// Update command command-wait event to wait on fill of new value
174+
ASSERT_SUCCESS(urCommandBufferUpdateWaitEventsExp(command_handles[0], 1,
175+
&external_events[2]));
176+
177+
// Get a new signal event for command-buffer
178+
ASSERT_SUCCESS(urCommandBufferUpdateSignalEventExp(command_handles[0],
179+
&external_events[3]));
180+
ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0,
181+
nullptr, nullptr));
182+
183+
ASSERT_SUCCESS(urEnqueueUSMMemcpy(queue, false, host_enqueue_ptr.data(),
184+
device_ptrs[0], allocation_size, 1,
185+
&external_events[3], nullptr));
186+
187+
// Verify update
188+
ASSERT_SUCCESS(urQueueFinish(queue));
189+
for (size_t i = 0; i < modifiedElementSize; i++) {
190+
ASSERT_EQ(host_enqueue_ptr[i], patternY);
191+
}
192+
}
193+
132194
TEST_P(CommandEventSyncUpdateTest, MemBufferCopyExp) {
133195
// Get wait event from queue fill on buffer 0
134196
uint32_t patternX = 42;
@@ -532,6 +594,68 @@ TEST_P(CommandEventSyncUpdateTest, MemBufferWriteRectExp) {
532594
}
533595
}
534596

597+
// Test fill using a large pattern size since implementations may need to handle
598+
// this differently.
599+
TEST_P(CommandEventSyncUpdateTest, MemBufferFillLargePatternExp) {
600+
// Device buffers are allocated in the test fixture with 32-bit values * num
601+
// elements, since we are doubling the pattern size we want to treat those
602+
// device pointers as if they were created with half the number of elements.
603+
constexpr size_t modifiedElementSize = elements / 2;
604+
// Get wait event from queue fill on buffer 0
605+
uint64_t patternX = 42;
606+
ASSERT_SUCCESS(urEnqueueMemBufferFill(queue, buffers[0], &patternX,
607+
sizeof(patternX), 0, allocation_size,
608+
0, nullptr, &external_events[0]));
609+
610+
// Test fill command overwriting buffer 0 based on queue event
611+
uint64_t patternY = 0xA;
612+
ASSERT_SUCCESS(urCommandBufferAppendMemBufferFillExp(
613+
updatable_cmd_buf_handle, buffers[0], &patternY, sizeof(patternY), 0,
614+
allocation_size, 0, nullptr, 1, &external_events[0], nullptr,
615+
&external_events[1], &command_handles[0]));
616+
ASSERT_NE(nullptr, command_handles[0]);
617+
ASSERT_SUCCESS(urCommandBufferFinalizeExp(updatable_cmd_buf_handle));
618+
ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0,
619+
nullptr, nullptr));
620+
621+
// Queue read buffer 0 based on event returned from command-buffer command
622+
std::array<uint64_t, modifiedElementSize> host_enqueue_ptr{};
623+
ASSERT_SUCCESS(urEnqueueMemBufferRead(
624+
queue, buffers[0], false, 0, allocation_size, host_enqueue_ptr.data(),
625+
1, &external_events[1], nullptr));
626+
627+
// Verify
628+
ASSERT_SUCCESS(urQueueFinish(queue));
629+
for (size_t i = 0; i < modifiedElementSize; i++) {
630+
ASSERT_EQ(host_enqueue_ptr[i], patternY);
631+
}
632+
633+
uint64_t patternZ = 666;
634+
ASSERT_SUCCESS(urEnqueueMemBufferFill(queue, buffers[0], &patternZ,
635+
sizeof(patternZ), 0, allocation_size,
636+
0, nullptr, &external_events[2]));
637+
638+
// Update command command-wait event to wait on fill of new value
639+
ASSERT_SUCCESS(urCommandBufferUpdateWaitEventsExp(command_handles[0], 1,
640+
&external_events[2]));
641+
642+
// Get a new signal event for command-buffer
643+
ASSERT_SUCCESS(urCommandBufferUpdateSignalEventExp(command_handles[0],
644+
&external_events[3]));
645+
646+
ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0,
647+
nullptr, nullptr));
648+
ASSERT_SUCCESS(urEnqueueMemBufferRead(
649+
queue, buffers[0], false, 0, allocation_size, host_enqueue_ptr.data(),
650+
1, &external_events[3], nullptr));
651+
652+
// Verify update
653+
ASSERT_SUCCESS(urQueueFinish(queue));
654+
for (size_t i = 0; i < modifiedElementSize; i++) {
655+
ASSERT_EQ(host_enqueue_ptr[i], patternY);
656+
}
657+
}
658+
535659
TEST_P(CommandEventSyncUpdateTest, MemBufferFillExp) {
536660
// Get wait event from queue fill on buffer 0
537661
uint32_t patternX = 42;

0 commit comments

Comments
 (0)