source/adapters/cuda/command_buffer.cpp: 45 changes (29 additions, 16 deletions)
@@ -236,16 +236,29 @@ static ur_result_t enqueueCommandBufferFillHelper(
EventWaitList));
}

// CUDA has no memset functions that allow setting values larger than 4
// bytes. The UR API lets you pass an arbitrary "pattern" to the buffer
// fill, which can be larger than 4 bytes. Calculate the number of steps
// here to decide whether the fill must be decomposed into multiple memset
// nodes.
size_t NumberOfSteps = PatternSize / sizeof(uint8_t);

// Graph node added to the graph. If multiple nodes are created, this will
// be set to the leaf node.
CUgraphNode GraphNode;
// Track any extra nodes created by decomposing the fill, so they can be
// passed to the command handle.
std::vector<CUgraphNode> DecomposedNodes;

if (NumberOfSteps > 4) {
DecomposedNodes.reserve(NumberOfSteps);
}

const size_t N = Size / PatternSize;
auto DstPtr = DstType == CU_MEMORYTYPE_DEVICE
? *static_cast<CUdeviceptr *>(DstDevice)
: (CUdeviceptr)DstDevice;

if ((PatternSize == 1) || (PatternSize == 2) || (PatternSize == 4)) {
if (NumberOfSteps <= 4) {
CUDA_MEMSET_NODE_PARAMS NodeParams = {};
NodeParams.dst = DstPtr;
NodeParams.elementSize = PatternSize;
@@ -276,14 +289,9 @@ static ur_result_t enqueueCommandBufferFillHelper(
&GraphNode, CommandBuffer->CudaGraph, DepsList.data(), DepsList.size(),
&NodeParams, CommandBuffer->Device->getNativeContext()));
} else {
// CUDA has no memset functions that allow setting values more than 4
// bytes. UR API lets you pass an arbitrary "pattern" to the buffer
// fill, which can be more than 4 bytes. We must break up the pattern
// into 1 byte values, and set the buffer using multiple strided calls.
// This means that one cuGraphAddMemsetNode call is made for every 1
// bytes in the pattern.

size_t NumberOfSteps = PatternSize / sizeof(uint8_t);
// We must break up the rest of the pattern into 1-byte values and set
// the buffer using multiple strided calls. This means that one
// cuGraphAddMemsetNode call is made for every remaining byte in the
// pattern.

// Update NodeParam
CUDA_MEMSET_NODE_PARAMS NodeParamsStepFirst = {};
@@ -294,12 +302,13 @@
NodeParamsStepFirst.value = *static_cast<const uint32_t *>(Pattern);
NodeParamsStepFirst.width = 1;

// Initial decomposed node depends on the provided external event wait
// nodes.
UR_CHECK_ERROR(cuGraphAddMemsetNode(
&GraphNode, CommandBuffer->CudaGraph, DepsList.data(), DepsList.size(),
&NodeParamsStepFirst, CommandBuffer->Device->getNativeContext()));

DepsList.clear();
DepsList.push_back(GraphNode);
DecomposedNodes.push_back(GraphNode);

// we walk up the pattern in 1-byte steps, and call cuMemset for each
// 1-byte chunk of the pattern.
@@ -319,13 +328,16 @@
NodeParamsStep.value = Value;
NodeParamsStep.width = 1;

// Copy the last GraphNode ptr so we can pass it as the dependency for
// the next one
CUgraphNode PrevNode = GraphNode;

UR_CHECK_ERROR(cuGraphAddMemsetNode(
&GraphNode, CommandBuffer->CudaGraph, DepsList.data(),
DepsList.size(), &NodeParamsStep,
&GraphNode, CommandBuffer->CudaGraph, &PrevNode, 1, &NodeParamsStep,
CommandBuffer->Device->getNativeContext()));

DepsList.clear();
DepsList.push_back(GraphNode);
// Store the decomposed node
DecomposedNodes.push_back(GraphNode);
}
}

@@ -344,7 +356,8 @@

std::vector<CUgraphNode> WaitNodes =
NumEventsInWaitList ? std::move(DepsList) : std::vector<CUgraphNode>();
auto NewCommand = new T(CommandBuffer, GraphNode, SignalNode, WaitNodes);
auto NewCommand = new T(CommandBuffer, GraphNode, SignalNode, WaitNodes,
std::move(DecomposedNodes));
CommandBuffer->CommandHandles.push_back(NewCommand);

if (RetCommand) {
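
Reviewer note on the decomposition above: the node parameters can be hard to picture on their own, so here is a minimal host-side sketch of the same strategy, written purely for illustration. The helper name and the assumption that the first node covers the leading 4 bytes of every pattern-sized element while the per-byte loop covers offsets 4 and up are mine, not code from this change.

#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// Host-side emulation of the decomposed fill: one write of the leading 4
// pattern bytes per element, then one strided 1-byte write per remaining
// pattern byte (width 1, height N, pitch PatternSize in the graph nodes).
static void emulateDecomposedFill(std::vector<uint8_t> &Buffer,
                                  const std::vector<uint8_t> &Pattern) {
  const size_t PatternSize = Pattern.size();
  const size_t N = Buffer.size() / PatternSize; // number of pattern repeats

  // Equivalent of NodeParamsStepFirst.
  for (size_t Element = 0; Element < N; ++Element)
    std::memcpy(&Buffer[Element * PatternSize], Pattern.data(), 4);

  // Equivalent of the per-byte NodeParamsStep loop.
  for (size_t Step = 4; Step < PatternSize; ++Step)
    for (size_t Element = 0; Element < N; ++Element)
      Buffer[Element * PatternSize + Step] = Pattern[Step];
}

int main() {
  const std::vector<uint8_t> Pattern = {1, 2, 3, 4, 5, 6, 7, 8};
  std::vector<uint8_t> Buffer(64, 0);
  emulateDecomposedFill(Buffer, Pattern);
  for (size_t I = 0; I < Buffer.size(); ++I)
    assert(Buffer[I] == Pattern[I % Pattern.size()]);
  return 0;
}

Filling a 64-byte buffer with an 8-byte pattern this way reproduces Pattern[i % PatternSize] at every byte, which is the same invariant the new large-pattern conformance tests below check with a 64-bit pattern.
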
source/adapters/cuda/command_buffer.hpp: 26 changes (20 additions, 6 deletions)
@@ -172,12 +172,19 @@ struct usm_memcpy_command_handle : ur_exp_command_buffer_command_handle_t_ {
struct usm_fill_command_handle : ur_exp_command_buffer_command_handle_t_ {
usm_fill_command_handle(ur_exp_command_buffer_handle_t CommandBuffer,
CUgraphNode Node, CUgraphNode SignalNode,
const std::vector<CUgraphNode> &WaitNodes)
const std::vector<CUgraphNode> &WaitNodes,
const std::vector<CUgraphNode> &DecomposedNodes = {})
: ur_exp_command_buffer_command_handle_t_(CommandBuffer, Node, SignalNode,
WaitNodes) {}
WaitNodes),
DecomposedNodes(std::move(DecomposedNodes)) {}
CommandType getCommandType() const noexcept override {
return CommandType::USMFill;
}

// If this fill command was decomposed into multiple nodes, this vector
// contains all of those nodes in the order they were added to the graph.
// Currently unused, but will be required for command update in the future.
std::vector<CUgraphNode> DecomposedNodes;
};

struct buffer_copy_command_handle : ur_exp_command_buffer_command_handle_t_ {
@@ -250,14 +257,21 @@ struct buffer_write_rect_command_handle
};

struct buffer_fill_command_handle : ur_exp_command_buffer_command_handle_t_ {
buffer_fill_command_handle(ur_exp_command_buffer_handle_t CommandBuffer,
CUgraphNode Node, CUgraphNode SignalNode,
const std::vector<CUgraphNode> &WaitNodes)
buffer_fill_command_handle(
ur_exp_command_buffer_handle_t CommandBuffer, CUgraphNode Node,
CUgraphNode SignalNode, const std::vector<CUgraphNode> &WaitNodes,
const std::vector<CUgraphNode> &DecomposedNodes = {})
: ur_exp_command_buffer_command_handle_t_(CommandBuffer, Node, SignalNode,
WaitNodes) {}
WaitNodes),
DecomposedNodes(std::move(DecomposedNodes)) {}
CommandType getCommandType() const noexcept override {
return CommandType::MemBufferFill;
}

// If this fill command was decomposed into multiple nodes, this vector
// contains all of those nodes in the order they were added to the graph.
// Currently unused, but will be required for command update in the future.
std::vector<CUgraphNode> DecomposedNodes;
};

struct usm_prefetch_command_handle : ur_exp_command_buffer_command_handle_t_ {
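
The DecomposedNodes member introduced in this header is stored but not yet consumed; the comments note it will be needed once decomposed fill commands support update. As a purely hypothetical sketch of how a future update path might walk those nodes (the helper name, the offset arithmetic, and the idea of patching the instantiated graph directly are assumptions, not part of this PR):

#include <cuda.h>
#include <vector>

// Hypothetical, not part of this change: re-target every node of a decomposed
// fill in an instantiated graph. Assumes Nodes[0] covers the leading 4 bytes
// of each pattern-sized element and Nodes[K] (K >= 1) covers byte K + 3.
static CUresult retargetDecomposedFill(CUgraphExec Exec, CUcontext Ctx,
                                       const std::vector<CUgraphNode> &Nodes,
                                       CUdeviceptr NewDst) {
  for (size_t I = 0; I < Nodes.size(); ++I) {
    CUDA_MEMSET_NODE_PARAMS Params = {};
    CUresult Err = cuGraphMemsetNodeGetParams(Nodes[I], &Params);
    if (Err != CUDA_SUCCESS)
      return Err;
    Params.dst = NewDst + (I == 0 ? 0 : I + 3);
    Err = cuGraphExecMemsetNodeSetParams(Exec, Nodes[I], &Params, Ctx);
    if (Err != CUDA_SUCCESS)
      return Err;
  }
  return CUDA_SUCCESS;
}
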
test/conformance/exp_command_buffer/event_sync.cpp: 72 changes (72 additions, 0 deletions)
@@ -75,6 +75,42 @@ TEST_P(CommandEventSyncTest, USMFillExp) {
}
}

// Test fill using a large pattern size since implementations may need to handle
// this differently.
TEST_P(CommandEventSyncTest, USMFillLargePatternExp) {
// Device ptrs are allocated in the test fixture as 32-bit values * num
// elements. Since we are doubling the pattern size, treat those device
// pointers as if they were created with half the number of elements.
constexpr size_t modifiedElementSize = elements / 2;
// Get wait event from queue fill on ptr 0
uint64_t patternX = 42;
ASSERT_SUCCESS(urEnqueueUSMFill(queue, device_ptrs[0], sizeof(patternX),
&patternX, allocation_size, 0, nullptr,
&external_events[0]));

// Test fill command overwriting ptr 0 waiting on queue event
uint64_t patternY = 0xA;
ASSERT_SUCCESS(urCommandBufferAppendUSMFillExp(
cmd_buf_handle, device_ptrs[0], &patternY, sizeof(patternY),
allocation_size, 0, nullptr, 1, &external_events[0], nullptr,
&external_events[1], nullptr));
ASSERT_SUCCESS(urCommandBufferFinalizeExp(cmd_buf_handle));
ASSERT_SUCCESS(
urCommandBufferEnqueueExp(cmd_buf_handle, queue, 0, nullptr, nullptr));

// Queue read ptr 0 based on event returned from command-buffer command
std::array<uint64_t, modifiedElementSize> host_enqueue_ptr{};
ASSERT_SUCCESS(urEnqueueUSMMemcpy(queue, false, host_enqueue_ptr.data(),
device_ptrs[0], allocation_size, 1,
&external_events[1], nullptr));

// Verify
ASSERT_SUCCESS(urQueueFinish(queue));
for (size_t i = 0; i < modifiedElementSize; i++) {
ASSERT_EQ(host_enqueue_ptr[i], patternY);
}
}

TEST_P(CommandEventSyncTest, MemBufferCopyExp) {
// Get wait event from queue fill on buffer 0
uint32_t patternX = 42;
@@ -341,6 +377,42 @@ TEST_P(CommandEventSyncTest, MemBufferFillExp) {
}
}

// Test fill using a large pattern size since implementations may need to handle
// this differently.
TEST_P(CommandEventSyncTest, MemBufferFillLargePatternExp) {
// Device buffers are allocated in the test fixture as 32-bit values * num
// elements. Since we are doubling the pattern size, treat those device
// buffers as if they were created with half the number of elements.
constexpr size_t modifiedElementSize = elements / 2;
// Get wait event from queue fill on buffer 0
uint64_t patternX = 42;
ASSERT_SUCCESS(urEnqueueMemBufferFill(queue, buffers[0], &patternX,
sizeof(patternX), 0, allocation_size,
0, nullptr, &external_events[0]));

// Test fill command overwriting buffer 0 based on queue event
uint64_t patternY = 0xA;
ASSERT_SUCCESS(urCommandBufferAppendMemBufferFillExp(
cmd_buf_handle, buffers[0], &patternY, sizeof(patternY), 0,
allocation_size, 0, nullptr, 1, &external_events[0], nullptr,
&external_events[1], nullptr));
ASSERT_SUCCESS(urCommandBufferFinalizeExp(cmd_buf_handle));
ASSERT_SUCCESS(
urCommandBufferEnqueueExp(cmd_buf_handle, queue, 0, nullptr, nullptr));

// Queue read buffer 0 based on event returned from command-buffer command
std::array<uint64_t, modifiedElementSize> host_enqueue_ptr{};
ASSERT_SUCCESS(urEnqueueMemBufferRead(
queue, buffers[0], false, 0, allocation_size, host_enqueue_ptr.data(),
1, &external_events[1], nullptr));

// Verify
ASSERT_SUCCESS(urQueueFinish(queue));
for (size_t i = 0; i < modifiedElementSize; i++) {
ASSERT_EQ(host_enqueue_ptr[i], patternY);
}
}

TEST_P(CommandEventSyncTest, USMPrefetchExp) {
// Get wait event from queue fill on ptr 0
uint32_t patternX = 42;
@@ -33,6 +33,8 @@ CommandEventSyncTest.USMPrefetchExp/*
CommandEventSyncTest.USMAdviseExp/*
CommandEventSyncTest.MultipleEventCommands/*
CommandEventSyncTest.MultipleEventCommandsBetweenCommandBuffers/*
CommandEventSyncTest.USMFillLargePatternExp/*
CommandEventSyncTest.MemBufferFillLargePatternExp/*
CommandEventSyncUpdateTest.USMMemcpyExp/*
CommandEventSyncUpdateTest.USMFillExp/*
CommandEventSyncUpdateTest.MemBufferCopyExp/*
@@ -45,3 +47,5 @@ CommandEventSyncUpdateTest.MemBufferFillExp/*
CommandEventSyncUpdateTest.USMPrefetchExp/*
CommandEventSyncUpdateTest.USMAdviseExp/*
CommandEventSyncUpdateTest.MultipleEventCommands/*
CommandEventSyncUpdateTest.USMFillLargePatternExp/*
CommandEventSyncUpdateTest.MemBufferFillLargePatternExp/*
test/conformance/exp_command_buffer/update/event_sync.cpp: 124 changes (124 additions, 0 deletions)
@@ -129,6 +129,68 @@ TEST_P(CommandEventSyncUpdateTest, USMFillExp) {
}
}

// Test fill using a large pattern size since implementations may need to handle
// this differently.
TEST_P(CommandEventSyncUpdateTest, USMFillLargePatternExp) {
// Device ptrs are allocated in the test fixture as 32-bit values * num
// elements. Since we are doubling the pattern size, treat those device
// pointers as if they were created with half the number of elements.
constexpr size_t modifiedElementSize = elements / 2;
// Get wait event from queue fill on ptr 0
uint64_t patternX = 42;
ASSERT_SUCCESS(urEnqueueUSMFill(queue, device_ptrs[0], sizeof(patternX),
&patternX, allocation_size, 0, nullptr,
&external_events[0]));

// Test fill command overwriting ptr 0 waiting on queue event
uint64_t patternY = 0xA;
ASSERT_SUCCESS(urCommandBufferAppendUSMFillExp(
updatable_cmd_buf_handle, device_ptrs[0], &patternY, sizeof(patternY),
allocation_size, 0, nullptr, 1, &external_events[0], nullptr,
&external_events[1], &command_handles[0]));
ASSERT_NE(nullptr, command_handles[0]);
ASSERT_SUCCESS(urCommandBufferFinalizeExp(updatable_cmd_buf_handle));
ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0,
nullptr, nullptr));

// Queue read ptr 0 based on event returned from command-buffer command
std::array<uint64_t, modifiedElementSize> host_enqueue_ptr{};
ASSERT_SUCCESS(urEnqueueUSMMemcpy(queue, false, host_enqueue_ptr.data(),
device_ptrs[0], allocation_size, 1,
&external_events[1], nullptr));

// Verify
ASSERT_SUCCESS(urQueueFinish(queue));
for (size_t i = 0; i < modifiedElementSize; i++) {
ASSERT_EQ(host_enqueue_ptr[i], patternY);
}

uint64_t patternZ = 666;
ASSERT_SUCCESS(urEnqueueUSMFill(queue, device_ptrs[0], sizeof(patternZ),
&patternZ, allocation_size, 0, nullptr,
&external_events[2]));

// Update the command's wait event to wait on the fill of the new value
ASSERT_SUCCESS(urCommandBufferUpdateWaitEventsExp(command_handles[0], 1,
&external_events[2]));

// Get a new signal event for command-buffer
ASSERT_SUCCESS(urCommandBufferUpdateSignalEventExp(command_handles[0],
&external_events[3]));
ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0,
nullptr, nullptr));

ASSERT_SUCCESS(urEnqueueUSMMemcpy(queue, false, host_enqueue_ptr.data(),
device_ptrs[0], allocation_size, 1,
&external_events[3], nullptr));

// Verify update
ASSERT_SUCCESS(urQueueFinish(queue));
for (size_t i = 0; i < modifiedElementSize; i++) {
ASSERT_EQ(host_enqueue_ptr[i], patternY);
}
}

TEST_P(CommandEventSyncUpdateTest, MemBufferCopyExp) {
// Get wait event from queue fill on buffer 0
uint32_t patternX = 42;
@@ -532,6 +594,68 @@ TEST_P(CommandEventSyncUpdateTest, MemBufferWriteRectExp) {
}
}

// Test fill using a large pattern size since implementations may need to handle
// this differently.
TEST_P(CommandEventSyncUpdateTest, MemBufferFillLargePatternExp) {
// Device buffers are allocated in the test fixture as 32-bit values * num
// elements. Since we are doubling the pattern size, treat those device
// buffers as if they were created with half the number of elements.
constexpr size_t modifiedElementSize = elements / 2;
// Get wait event from queue fill on buffer 0
uint64_t patternX = 42;
ASSERT_SUCCESS(urEnqueueMemBufferFill(queue, buffers[0], &patternX,
sizeof(patternX), 0, allocation_size,
0, nullptr, &external_events[0]));

// Test fill command overwriting buffer 0 based on queue event
uint64_t patternY = 0xA;
ASSERT_SUCCESS(urCommandBufferAppendMemBufferFillExp(
updatable_cmd_buf_handle, buffers[0], &patternY, sizeof(patternY), 0,
allocation_size, 0, nullptr, 1, &external_events[0], nullptr,
&external_events[1], &command_handles[0]));
ASSERT_NE(nullptr, command_handles[0]);
ASSERT_SUCCESS(urCommandBufferFinalizeExp(updatable_cmd_buf_handle));
ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0,
nullptr, nullptr));

// Queue read buffer 0 based on event returned from command-buffer command
std::array<uint64_t, modifiedElementSize> host_enqueue_ptr{};
ASSERT_SUCCESS(urEnqueueMemBufferRead(
queue, buffers[0], false, 0, allocation_size, host_enqueue_ptr.data(),
1, &external_events[1], nullptr));

// Verify
ASSERT_SUCCESS(urQueueFinish(queue));
for (size_t i = 0; i < modifiedElementSize; i++) {
ASSERT_EQ(host_enqueue_ptr[i], patternY);
}

uint64_t patternZ = 666;
ASSERT_SUCCESS(urEnqueueMemBufferFill(queue, buffers[0], &patternZ,
sizeof(patternZ), 0, allocation_size,
0, nullptr, &external_events[2]));

// Update the command's wait event to wait on the fill of the new value
ASSERT_SUCCESS(urCommandBufferUpdateWaitEventsExp(command_handles[0], 1,
&external_events[2]));

// Get a new signal event for command-buffer
ASSERT_SUCCESS(urCommandBufferUpdateSignalEventExp(command_handles[0],
&external_events[3]));

ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0,
nullptr, nullptr));
ASSERT_SUCCESS(urEnqueueMemBufferRead(
queue, buffers[0], false, 0, allocation_size, host_enqueue_ptr.data(),
1, &external_events[3], nullptr));

// Verify update
ASSERT_SUCCESS(urQueueFinish(queue));
for (size_t i = 0; i < modifiedElementSize; i++) {
ASSERT_EQ(host_enqueue_ptr[i], patternY);
}
}

TEST_P(CommandEventSyncUpdateTest, MemBufferFillExp) {
// Get wait event from queue fill on buffer 0
uint32_t patternX = 42;