Skip to content

Commit 3d9aa64

Browse files
committed
Merge branch 'main' into user-after-free
2 parents e198d84 + c49b116 commit 3d9aa64

File tree

18 files changed

+264
-71
lines changed

18 files changed

+264
-71
lines changed

source/adapters/cuda/command_buffer.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
368368
try {
369369
// Set node param structure with the kernel related data
370370
auto &ArgIndices = hKernel->getArgIndices();
371-
CUDA_KERNEL_NODE_PARAMS NodeParams;
371+
CUDA_KERNEL_NODE_PARAMS NodeParams = {};
372372
NodeParams.func = CuFunc;
373373
NodeParams.gridDimX = BlocksPerGrid[0];
374374
NodeParams.gridDimY = BlocksPerGrid[1];
@@ -378,8 +378,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
378378
NodeParams.blockDimZ = ThreadsPerBlock[2];
379379
NodeParams.sharedMemBytes = LocalSize;
380380
NodeParams.kernelParams = const_cast<void **>(ArgIndices.data());
381-
NodeParams.kern = nullptr;
382-
NodeParams.extra = nullptr;
383381

384382
// Create and add an new kernel node to the Cuda graph
385383
UR_CHECK_ERROR(cuGraphAddKernelNode(&GraphNode, hCommandBuffer->CudaGraph,

source/adapters/level_zero/device.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1054,6 +1054,19 @@ bool ur_device_handle_t_::useRelaxedAllocationLimits() {
10541054
return EnableRelaxedAllocationLimits;
10551055
}
10561056

1057+
bool ur_device_handle_t_::useDriverInOrderLists() {
1058+
// Use in-order lists implementation from L0 driver instead
1059+
// of adapter's implementation.
1060+
static const bool UseDriverInOrderLists = [] {
1061+
const char *UrRet = std::getenv("UR_L0_USE_DRIVER_INORDER_LISTS");
1062+
if (!UrRet)
1063+
return false;
1064+
return std::atoi(UrRet) != 0;
1065+
}();
1066+
1067+
return UseDriverInOrderLists;
1068+
}
1069+
10571070
ur_result_t ur_device_handle_t_::initialize(int SubSubDeviceOrdinal,
10581071
int SubSubDeviceIndex) {
10591072
// Maintain various device properties cache.

source/adapters/level_zero/device.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,9 @@ struct ur_device_handle_t_ : _ur_object {
143143
// Read env settings to select immediate commandlist mode.
144144
ImmCmdlistMode useImmediateCommandLists();
145145

146+
// Whether Adapter uses driver's implementation of in-order lists or not
147+
bool useDriverInOrderLists();
148+
146149
// Returns whether immediate command lists are used on this device.
147150
ImmCmdlistMode ImmCommandListUsed{};
148151

source/adapters/level_zero/event.cpp

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,20 @@ static const bool UseMultipleCmdlistBarriers = [] {
4343
return std::atoi(UseMultipleCmdlistBarriersFlag) > 0;
4444
}();
4545

46+
bool WaitListEmptyOrAllEventsFromSameQueue(
47+
ur_queue_handle_t Queue, uint32_t NumEventsInWaitList,
48+
const ur_event_handle_t *EventWaitList) {
49+
if (!NumEventsInWaitList)
50+
return true;
51+
52+
for (uint32_t i = 0; i < NumEventsInWaitList; ++i) {
53+
if (Queue != EventWaitList[i]->UrQueue)
54+
return false;
55+
}
56+
57+
return true;
58+
}
59+
4660
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueEventsWait(
4761
ur_queue_handle_t Queue, ///< [in] handle of the queue object
4862
uint32_t NumEventsInWaitList, ///< [in] size of the event wait list
@@ -206,21 +220,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueEventsWaitWithBarrier(
206220
bool IsInternal = OutEvent == nullptr;
207221
ur_event_handle_t *Event = OutEvent ? OutEvent : &InternalEvent;
208222

209-
auto WaitListEmptyOrAllEventsFromSameQueue = [Queue, NumEventsInWaitList,
210-
EventWaitList]() {
211-
if (!NumEventsInWaitList)
212-
return true;
213-
214-
for (uint32_t I = 0; I < NumEventsInWaitList; ++I)
215-
if (Queue != EventWaitList[I]->UrQueue)
216-
return false;
217-
218-
return true;
219-
};
220-
221223
// For in-order queue and wait-list which is empty or has events from
222224
// the same queue just use the last command event as the barrier event.
223-
if (Queue->isInOrderQueue() && WaitListEmptyOrAllEventsFromSameQueue() &&
225+
if (Queue->isInOrderQueue() &&
226+
WaitListEmptyOrAllEventsFromSameQueue(Queue, NumEventsInWaitList,
227+
EventWaitList) &&
224228
Queue->LastCommandEvent && !Queue->LastCommandEvent->IsDiscarded) {
225229
UR_CALL(urEventRetain(Queue->LastCommandEvent));
226230
*Event = Queue->LastCommandEvent;
@@ -1189,6 +1193,23 @@ ur_result_t _ur_ze_event_list_t::createAndRetainUrZeEventList(
11891193
CurQueue->LastCommandEvent && CurQueue->LastCommandEvent->IsDiscarded)
11901194
IncludeLastCommandEvent = false;
11911195

1196+
// If we are using L0 native implementation for handling in-order queues,
1197+
// then we don't need to add the last enqueued event into the waitlist, as
1198+
// the native driver implementation will already ensure in-order semantics.
1199+
// The only exception is when a different immediate command was last used on
1200+
// the same UR Queue.
1201+
if (CurQueue->Device->useDriverInOrderLists() && CurQueue->isInOrderQueue() &&
1202+
CurQueue->UsingImmCmdLists) {
1203+
auto QueueGroup = CurQueue->getQueueGroup(UseCopyEngine);
1204+
uint32_t QueueGroupOrdinal, QueueIndex;
1205+
auto NextIndex = QueueGroup.getQueueIndex(&QueueGroupOrdinal, &QueueIndex,
1206+
/*QueryOnly */ true);
1207+
auto NextImmCmdList = QueueGroup.ImmCmdLists[NextIndex];
1208+
IncludeLastCommandEvent &=
1209+
CurQueue->LastUsedCommandList != CurQueue->CommandListMap.end() &&
1210+
NextImmCmdList != CurQueue->LastUsedCommandList;
1211+
}
1212+
11921213
try {
11931214
uint32_t TmpListLength = 0;
11941215

@@ -1205,6 +1226,16 @@ ur_result_t _ur_ze_event_list_t::createAndRetainUrZeEventList(
12051226
this->UrEventList = new ur_event_handle_t[EventListLength];
12061227
}
12071228

1229+
// For in-order queue and wait-list which is empty or has events only from
1230+
// the same queue then we don't need to wait on any other additional events
1231+
if (CurQueue->Device->useDriverInOrderLists() &&
1232+
CurQueue->isInOrderQueue() &&
1233+
WaitListEmptyOrAllEventsFromSameQueue(CurQueue, EventListLength,
1234+
EventList)) {
1235+
this->Length = TmpListLength;
1236+
return UR_RESULT_SUCCESS;
1237+
}
1238+
12081239
if (EventListLength > 0) {
12091240
for (uint32_t I = 0; I < EventListLength; I++) {
12101241
{

source/adapters/level_zero/image.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -738,13 +738,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesSampledImageCreateExp(
738738
hContext, hDevice, hImageMem, pImageFormat, pImageDesc, phMem, phImage));
739739

740740
struct combined_sampled_image_handle {
741-
uint64_t raw_image_handle;
742-
uint64_t raw_sampler_handle;
741+
uint64_t RawImageHandle;
742+
uint64_t RawSamplerHandle;
743743
};
744-
combined_sampled_image_handle *sampledImageHandle =
744+
auto *SampledImageHandle =
745745
reinterpret_cast<combined_sampled_image_handle *>(phImage);
746-
sampledImageHandle->raw_image_handle = reinterpret_cast<uint64_t>(*phImage);
747-
sampledImageHandle->raw_sampler_handle =
746+
SampledImageHandle->RawSamplerHandle =
748747
reinterpret_cast<uint64_t>(hSampler->ZeSampler);
749748

750749
return UR_RESULT_SUCCESS;

source/adapters/level_zero/kernel.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
214214
// the code can do a urKernelRelease on this kernel.
215215
(*Event)->CommandData = (void *)Kernel;
216216

217-
// Increment the reference count of the Kernel and indicate that the Kernel is
218-
// in use. Once the event has been signalled, the code in
217+
// Increment the reference count of the Kernel and indicate that the Kernel
218+
// is in use. Once the event has been signalled, the code in
219219
// CleanupCompletedEvent(Event) will do a urKernelRelease to update the
220220
// reference count on the kernel, using the kernel saved in CommandData.
221221
UR_CALL(urKernelRetain(Kernel));

source/adapters/level_zero/queue.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1870,6 +1870,10 @@ ur_result_t ur_queue_handle_t_::createCommandList(
18701870
ZeStruct<ze_command_list_desc_t> ZeCommandListDesc;
18711871
ZeCommandListDesc.commandQueueGroupOrdinal = QueueGroupOrdinal;
18721872

1873+
if (Device->useDriverInOrderLists() && isInOrderQueue()) {
1874+
ZeCommandListDesc.flags = ZE_COMMAND_LIST_FLAG_IN_ORDER;
1875+
}
1876+
18731877
ZE2UR_CALL(zeCommandListCreate, (Context->ZeContext, Device->ZeDevice,
18741878
&ZeCommandListDesc, &ZeCommandList));
18751879

@@ -1985,7 +1989,11 @@ ur_command_list_ptr_t &ur_queue_handle_t_::ur_queue_group_t::getImmCmdList() {
19851989

19861990
// Evaluate performance of explicit usage for "0" index.
19871991
if (QueueIndex != 0) {
1988-
ZeCommandQueueDesc.flags = ZE_COMMAND_QUEUE_FLAG_EXPLICIT_ONLY;
1992+
ZeCommandQueueDesc.flags |= ZE_COMMAND_QUEUE_FLAG_EXPLICIT_ONLY;
1993+
}
1994+
1995+
if (Queue->Device->useDriverInOrderLists() && Queue->isInOrderQueue()) {
1996+
ZeCommandQueueDesc.flags |= ZE_COMMAND_QUEUE_FLAG_IN_ORDER;
19891997
}
19901998

19911999
// Check if context's command list cache has an immediate command list with

source/loader/ur_lib.cpp

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -224,10 +224,6 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
224224
if (!hPlatform) {
225225
return UR_RESULT_ERROR_INVALID_NULL_HANDLE;
226226
}
227-
// NumEntries is max number of devices wanted by the caller (max usable length of phDevices)
228-
if (NumEntries < 0) {
229-
return UR_RESULT_ERROR_INVALID_SIZE;
230-
}
231227
if (NumEntries > 0 && !phDevices) {
232228
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
233229
}
@@ -426,8 +422,10 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
426422

427423
for (auto &termPair : mapODS) {
428424
std::string backend = termPair.first;
429-
if (backend
430-
.empty()) { // FIXME: never true because getenv_to_map rejects this case
425+
// TODO: Figure out how to process all ODS errors rather than returning
426+
// on the first error.
427+
if (backend.empty()) {
428+
// FIXME: never true because getenv_to_map rejects this case
431429
// malformed term: missing backend -- output ERROR, then continue
432430
logger::error("ERROR: missing backend, format of filter = "
433431
"'[!]backend:filterStrings'");
@@ -459,20 +457,19 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
459457
std::tolower(static_cast<unsigned char>(b));
460458
})) {
461459
// irrelevant term for current request: different backend -- silently ignore
462-
logger::warning(
463-
"WARNING: ignoring term with irrelevant backend '{}'", backend);
464-
continue;
460+
logger::error("unrecognised backend '{}'", backend);
461+
return UR_RESULT_ERROR_INVALID_VALUE;
465462
}
466463
if (termPair.second.size() == 0) {
467-
// malformed term: missing filterStrings -- output ERROR, then continue
468-
logger::error("ERROR missing filterStrings, format of filter = "
464+
// malformed term: missing filterStrings -- output ERROR
465+
logger::error("missing filterStrings, format of filter = "
469466
"'[!]backend:filterStrings'");
470-
continue;
467+
return UR_RESULT_ERROR_INVALID_VALUE;
471468
}
472469
if (std::find_if(termPair.second.cbegin(), termPair.second.cend(),
473470
[](const auto &s) { return s.empty(); }) !=
474-
termPair.second
475-
.cend()) { // FIXME: never true because getenv_to_map rejects this case
471+
termPair.second.cend()) {
472+
// FIXME: never true because getenv_to_map rejects this case
476473
// malformed term: missing filterString -- output warning, then continue
477474
logger::warning(
478475
"WARNING: empty filterString, format of filterStrings "
@@ -483,10 +480,10 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
483480
[](const auto &s) {
484481
return std::count(s.cbegin(), s.cend(), '.') > 2;
485482
}) != termPair.second.cend()) {
486-
// malformed term: too many dots in filterString -- output warning, then continue
487-
logger::warning("WARNING: too many dots in filterString, format of "
488-
"filterString = 'root[.sub[.subsub]]'");
489-
continue;
483+
// malformed term: too many dots in filterString
484+
logger::error("too many dots in filterString, format of "
485+
"filterString = 'root[.sub[.subsub]]'");
486+
return UR_RESULT_ERROR_INVALID_VALUE;
490487
}
491488
if (std::find_if(
492489
termPair.second.cbegin(), termPair.second.cend(),
@@ -504,10 +501,9 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
504501
}
505502
return false; // no BAD things, so must be okay
506503
}) != termPair.second.cend()) {
507-
// malformed term: star dot no-star in filterString -- output warning, then continue
508-
logger::warning(
509-
"WARNING: invalid wildcard in filterString, '*.' => '*.*'");
510-
continue;
504+
// malformed term: star dot no-star in filterString
505+
logger::error("invalid wildcard in filterString, '*.' => '*.*'");
506+
return UR_RESULT_ERROR_INVALID_VALUE;
511507
}
512508

513509
// TODO -- use regex validation_pattern to catch all other syntax errors in the ODS string
@@ -552,7 +548,7 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
552548

553549
if (acceptDeviceList.size() == 0 && discardDeviceList.size() == 0) {
554550
// nothing in env var was understood as a valid term
555-
return UR_RESULT_ERROR_INVALID_VALUE;
551+
return UR_RESULT_SUCCESS;
556552
} else if (acceptDeviceList.size() == 0) {
557553
// no accept terms were understood, but at least one discard term was
558554
// we are magnanimous to the user when there were bad/ignored accept terms

test/conformance/device/urDeviceGet.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,42 @@ TEST_F(urDeviceGetTest, SuccessSubsetOfDevices) {
3535
}
3636
}
3737

38+
struct urDeviceGetTestWithDeviceTypeParam
39+
: uur::urAllDevicesTest,
40+
::testing::WithParamInterface<ur_device_type_t> {
41+
42+
void SetUp() override {
43+
UUR_RETURN_ON_FATAL_FAILURE(uur::urAllDevicesTest::SetUp());
44+
}
45+
};
46+
47+
INSTANTIATE_TEST_SUITE_P(
48+
, urDeviceGetTestWithDeviceTypeParam,
49+
::testing::Values(UR_DEVICE_TYPE_DEFAULT, UR_DEVICE_TYPE_GPU,
50+
UR_DEVICE_TYPE_CPU, UR_DEVICE_TYPE_FPGA,
51+
UR_DEVICE_TYPE_MCA, UR_DEVICE_TYPE_VPU),
52+
[](const ::testing::TestParamInfo<ur_device_type_t> &info) {
53+
std::stringstream ss;
54+
ss << info.param;
55+
return ss.str();
56+
});
57+
58+
TEST_P(urDeviceGetTestWithDeviceTypeParam, Success) {
59+
ur_device_type_t device_type = GetParam();
60+
uint32_t count = 0;
61+
ASSERT_SUCCESS(urDeviceGet(platform, device_type, 0, nullptr, &count));
62+
ASSERT_GE(devices.size(), count);
63+
64+
if (count > 0) {
65+
std::vector<ur_device_handle_t> devices(count);
66+
ASSERT_SUCCESS(
67+
urDeviceGet(platform, device_type, count, devices.data(), nullptr));
68+
for (auto device : devices) {
69+
ASSERT_NE(nullptr, device);
70+
}
71+
}
72+
}
73+
3874
TEST_F(urDeviceGetTest, InvalidNullHandlePlatform) {
3975
uint32_t count;
4076
ASSERT_EQ_RESULT(

test/conformance/enqueue/urEnqueueMemBufferMap.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,21 @@ TEST_P(urEnqueueMemBufferMapTest, SuccessRead) {
2121
}
2222
}
2323

24-
TEST_P(urEnqueueMemBufferMapTest, SuccessWrite) {
24+
using urEnqueueMemBufferMapTestWithWriteFlagParam =
25+
uur::urMemBufferQueueTestWithParam<ur_map_flag_t>;
26+
UUR_TEST_SUITE_P(urEnqueueMemBufferMapTestWithWriteFlagParam,
27+
::testing::Values(UR_MAP_FLAG_WRITE,
28+
UR_MAP_FLAG_WRITE_INVALIDATE_REGION),
29+
uur::deviceTestWithParamPrinter<ur_map_flag_t>);
30+
31+
TEST_P(urEnqueueMemBufferMapTestWithWriteFlagParam, SuccessWrite) {
2532
const std::vector<uint32_t> input(count, 0);
2633
ASSERT_SUCCESS(urEnqueueMemBufferWrite(queue, buffer, true, 0, size,
2734
input.data(), 0, nullptr, nullptr));
2835

2936
uint32_t *map = nullptr;
30-
ASSERT_SUCCESS(urEnqueueMemBufferMap(queue, buffer, true, UR_MAP_FLAG_WRITE,
31-
0, size, 0, nullptr, nullptr,
37+
ASSERT_SUCCESS(urEnqueueMemBufferMap(queue, buffer, true, getParam(), 0,
38+
size, 0, nullptr, nullptr,
3239
(void **)&map));
3340
for (unsigned i = 0; i < count; ++i) {
3441
map[i] = 42;

0 commit comments

Comments
 (0)