Skip to content

Commit 5df5f45

Browse files
authored
[UR][SYCL] Add support for zeCommandListAppendLaunchKernelWithArguments() (#20316)
Signed-off-by: Lukasz Dorau <[email protected]>
1 parent 8b4629c commit 5df5f45

File tree

89 files changed

+3663
-453
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

89 files changed

+3663
-453
lines changed

sycl/source/detail/scheduler/commands.cpp

Lines changed: 174 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -2303,14 +2303,22 @@ ur_mem_flags_t AccessModeToUr(access::mode AccessorMode) {
23032303
}
23042304
}
23052305

2306-
// Sets arguments for a given kernel and device based on the argument type.
2307-
// Refactored from SetKernelParamsAndLaunch to allow it to be used in the graphs
2308-
// extension.
2309-
static void SetArgBasedOnType(
2310-
adapter_impl &Adapter, ur_kernel_handle_t Kernel,
2306+
// Gets UR argument struct for a given kernel and device based on the argument
2307+
// type. Refactored from SetKernelParamsAndLaunch to allow it to be used in
2308+
// the graphs extension (LaunchWithArgs for graphs is planned future work).
2309+
static void GetUrArgsBasedOnType(
23112310
device_image_impl *DeviceImageImpl,
23122311
const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
2313-
context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex) {
2312+
context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex,
2313+
std::vector<ur_exp_kernel_arg_properties_t> &UrArgs) {
2314+
// UrArg.size == 0 indicates uninitialized structure
2315+
ur_exp_kernel_arg_properties_t UrArg = {
2316+
UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES,
2317+
nullptr,
2318+
UR_EXP_KERNEL_ARG_TYPE_VALUE,
2319+
static_cast<uint32_t>(NextTrueIndex),
2320+
0,
2321+
{}};
23142322
switch (Arg.MType) {
23152323
case kernel_param_kind_t::kind_dynamic_work_group_memory:
23162324
break;
@@ -2330,52 +2338,56 @@ static void SetArgBasedOnType(
23302338
getMemAllocationFunc
23312339
? reinterpret_cast<ur_mem_handle_t>(getMemAllocationFunc(Req))
23322340
: nullptr;
2333-
ur_kernel_arg_mem_obj_properties_t MemObjData{};
2334-
MemObjData.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2335-
MemObjData.memoryAccess = AccessModeToUr(Req->MAccessMode);
2336-
Adapter.call<UrApiKind::urKernelSetArgMemObj>(Kernel, NextTrueIndex,
2337-
&MemObjData, MemArg);
2341+
ur_exp_kernel_arg_value_t Value = {};
2342+
Value.memObjTuple = {MemArg, AccessModeToUr(Req->MAccessMode)};
2343+
UrArg.type = UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ;
2344+
UrArg.size = sizeof(MemArg);
2345+
UrArg.value = Value;
23382346
break;
23392347
}
23402348
case kernel_param_kind_t::kind_std_layout: {
2349+
ur_exp_kernel_arg_type_t Type;
23412350
if (Arg.MPtr) {
2342-
Adapter.call<UrApiKind::urKernelSetArgValue>(
2343-
Kernel, NextTrueIndex, Arg.MSize, nullptr, Arg.MPtr);
2351+
Type = UR_EXP_KERNEL_ARG_TYPE_VALUE;
23442352
} else {
2345-
Adapter.call<UrApiKind::urKernelSetArgLocal>(Kernel, NextTrueIndex,
2346-
Arg.MSize, nullptr);
2353+
Type = UR_EXP_KERNEL_ARG_TYPE_LOCAL;
23472354
}
2348-
2355+
ur_exp_kernel_arg_value_t Value = {};
2356+
Value.value = {Arg.MPtr};
2357+
UrArg.type = Type;
2358+
UrArg.size = static_cast<size_t>(Arg.MSize);
2359+
UrArg.value = Value;
23492360
break;
23502361
}
23512362
case kernel_param_kind_t::kind_sampler: {
23522363
sampler *SamplerPtr = (sampler *)Arg.MPtr;
2353-
ur_sampler_handle_t Sampler =
2354-
(ur_sampler_handle_t)detail::getSyclObjImpl(*SamplerPtr)
2355-
->getOrCreateSampler(ContextImpl);
2356-
Adapter.call<UrApiKind::urKernelSetArgSampler>(Kernel, NextTrueIndex,
2357-
nullptr, Sampler);
2364+
ur_exp_kernel_arg_value_t Value = {};
2365+
Value.sampler = (ur_sampler_handle_t)detail::getSyclObjImpl(*SamplerPtr)
2366+
->getOrCreateSampler(ContextImpl);
2367+
UrArg.type = UR_EXP_KERNEL_ARG_TYPE_SAMPLER;
2368+
UrArg.size = sizeof(ur_sampler_handle_t);
2369+
UrArg.value = Value;
23582370
break;
23592371
}
23602372
case kernel_param_kind_t::kind_pointer: {
2361-
// We need to de-rerence this to get the actual USM allocation - that's the
2373+
ur_exp_kernel_arg_value_t Value = {};
2374+
// We need to de-rerence to get the actual USM allocation - that's the
23622375
// pointer UR is expecting.
2363-
const void *Ptr = *static_cast<const void *const *>(Arg.MPtr);
2364-
Adapter.call<UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2365-
nullptr, Ptr);
2376+
Value.pointer = *static_cast<void *const *>(Arg.MPtr);
2377+
UrArg.type = UR_EXP_KERNEL_ARG_TYPE_POINTER;
2378+
UrArg.size = sizeof(Arg.MPtr);
2379+
UrArg.value = Value;
23662380
break;
23672381
}
23682382
case kernel_param_kind_t::kind_specialization_constants_buffer: {
23692383
assert(DeviceImageImpl != nullptr);
23702384
ur_mem_handle_t SpecConstsBuffer =
23712385
DeviceImageImpl->get_spec_const_buffer_ref();
2372-
2373-
ur_kernel_arg_mem_obj_properties_t MemObjProps{};
2374-
MemObjProps.pNext = nullptr;
2375-
MemObjProps.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2376-
MemObjProps.memoryAccess = UR_MEM_FLAG_READ_ONLY;
2377-
Adapter.call<UrApiKind::urKernelSetArgMemObj>(
2378-
Kernel, NextTrueIndex, &MemObjProps, SpecConstsBuffer);
2386+
ur_exp_kernel_arg_value_t Value = {};
2387+
Value.memObjTuple = {SpecConstsBuffer, UR_MEM_FLAG_READ_ONLY};
2388+
UrArg.type = UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ;
2389+
UrArg.size = sizeof(SpecConstsBuffer);
2390+
UrArg.value = Value;
23792391
break;
23802392
}
23812393
case kernel_param_kind_t::kind_invalid:
@@ -2384,6 +2396,10 @@ static void SetArgBasedOnType(
23842396
codeToString(UR_RESULT_ERROR_INVALID_VALUE));
23852397
break;
23862398
}
2399+
2400+
if (UrArg.size) {
2401+
UrArgs.push_back(UrArg);
2402+
}
23872403
}
23882404

23892405
static ur_result_t SetKernelParamsAndLaunch(
@@ -2404,22 +2420,33 @@ static ur_result_t SetKernelParamsAndLaunch(
24042420
DeviceImageImpl ? DeviceImageImpl->get_spec_const_blob_ref() : Empty);
24052421
}
24062422

2423+
// just a performance optimization - avoid heap allocations
2424+
static thread_local std::vector<ur_exp_kernel_arg_properties_t> UrArgs;
2425+
UrArgs.clear();
2426+
UrArgs.reserve(Args.size());
2427+
24072428
if (KernelFuncPtr && !DeviceKernelInfo.HasSpecialCaptures) {
2408-
auto setFunc = [&Adapter, Kernel,
2409-
KernelFuncPtr](const detail::kernel_param_desc_t &ParamDesc,
2429+
auto setFunc = [KernelFuncPtr](const detail::kernel_param_desc_t &ParamDesc,
24102430
size_t NextTrueIndex) {
24112431
const void *ArgPtr = (const char *)KernelFuncPtr + ParamDesc.offset;
24122432
switch (ParamDesc.kind) {
24132433
case kernel_param_kind_t::kind_std_layout: {
24142434
int Size = ParamDesc.info;
2415-
Adapter.call<UrApiKind::urKernelSetArgValue>(Kernel, NextTrueIndex,
2416-
Size, nullptr, ArgPtr);
2435+
ur_exp_kernel_arg_value_t Value = {};
2436+
Value.value = ArgPtr;
2437+
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
2438+
UR_EXP_KERNEL_ARG_TYPE_VALUE,
2439+
static_cast<uint32_t>(NextTrueIndex),
2440+
static_cast<size_t>(Size), Value});
24172441
break;
24182442
}
24192443
case kernel_param_kind_t::kind_pointer: {
2420-
const void *Ptr = *static_cast<const void *const *>(ArgPtr);
2421-
Adapter.call<UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2422-
nullptr, Ptr);
2444+
ur_exp_kernel_arg_value_t Value = {};
2445+
Value.pointer = *static_cast<const void *const *>(ArgPtr);
2446+
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
2447+
UR_EXP_KERNEL_ARG_TYPE_POINTER,
2448+
static_cast<uint32_t>(NextTrueIndex),
2449+
sizeof(Value.pointer), Value});
24232450
break;
24242451
}
24252452
default:
@@ -2429,23 +2456,28 @@ static ur_result_t SetKernelParamsAndLaunch(
24292456
applyFuncOnFilteredArgs(EliminatedArgMask, DeviceKernelInfo.NumParams,
24302457
DeviceKernelInfo.ParamDescGetter, setFunc);
24312458
} else {
2432-
auto setFunc = [&Adapter, Kernel, &DeviceImageImpl, &getMemAllocationFunc,
2459+
auto setFunc = [&DeviceImageImpl, &getMemAllocationFunc,
24332460
&Queue](detail::ArgDesc &Arg, size_t NextTrueIndex) {
2434-
SetArgBasedOnType(Adapter, Kernel, DeviceImageImpl, getMemAllocationFunc,
2435-
Queue.getContextImpl(), Arg, NextTrueIndex);
2461+
GetUrArgsBasedOnType(DeviceImageImpl, getMemAllocationFunc,
2462+
Queue.getContextImpl(), Arg, NextTrueIndex, UrArgs);
24362463
};
24372464
applyFuncOnFilteredArgs(EliminatedArgMask, Args, setFunc);
24382465
}
24392466

2440-
const std::optional<int> &ImplicitLocalArg =
2441-
DeviceKernelInfo.getImplicitLocalArgPos();
2467+
std::optional<int> ImplicitLocalArg =
2468+
ProgramManager::getInstance().kernelImplicitLocalArgPos(
2469+
DeviceKernelInfo.Name);
24422470
// Set the implicit local memory buffer to support
24432471
// get_work_group_scratch_memory. This is for backend not supporting
24442472
// CUDA-style local memory setting. Note that we may have -1 as a position,
24452473
// this indicates the buffer is actually unused and was elided.
24462474
if (ImplicitLocalArg.has_value() && ImplicitLocalArg.value() != -1) {
2447-
Adapter.call<UrApiKind::urKernelSetArgLocal>(
2448-
Kernel, ImplicitLocalArg.value(), WorkGroupMemorySize, nullptr);
2475+
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES,
2476+
nullptr,
2477+
UR_EXP_KERNEL_ARG_TYPE_LOCAL,
2478+
static_cast<uint32_t>(ImplicitLocalArg.value()),
2479+
WorkGroupMemorySize,
2480+
{nullptr}});
24492481
}
24502482

24512483
adjustNDRangePerKernel(NDRDesc, Kernel, Queue.getDeviceImpl());
@@ -2468,16 +2500,14 @@ static ur_result_t SetKernelParamsAndLaunch(
24682500
/* pPropSizeRet = */ nullptr);
24692501

24702502
const bool EnforcedLocalSize =
2471-
(RequiredWGSize[0] != 0 &&
2472-
(NDRDesc.Dims < 2 || RequiredWGSize[1] != 0) &&
2473-
(NDRDesc.Dims < 3 || RequiredWGSize[2] != 0));
2503+
(RequiredWGSize[0] != 0 || RequiredWGSize[1] != 0 ||
2504+
RequiredWGSize[2] != 0);
24742505
if (EnforcedLocalSize)
24752506
LocalSize = RequiredWGSize;
24762507
}
2477-
2478-
const bool HasOffset = NDRDesc.GlobalOffset[0] != 0 &&
2479-
(NDRDesc.Dims < 2 || NDRDesc.GlobalOffset[1] != 0) &&
2480-
(NDRDesc.Dims < 3 || NDRDesc.GlobalOffset[2] != 0);
2508+
const bool HasOffset = NDRDesc.GlobalOffset[0] != 0 ||
2509+
NDRDesc.GlobalOffset[1] != 0 ||
2510+
NDRDesc.GlobalOffset[2] != 0;
24812511

24822512
std::vector<ur_kernel_launch_property_t> property_list;
24832513

@@ -2505,20 +2535,104 @@ static ur_result_t SetKernelParamsAndLaunch(
25052535
{{WorkGroupMemorySize}}});
25062536
}
25072537
ur_event_handle_t UREvent = nullptr;
2508-
ur_result_t Error = Adapter.call_nocheck<UrApiKind::urEnqueueKernelLaunch>(
2509-
Queue.getHandleRef(), Kernel, NDRDesc.Dims,
2510-
HasOffset ? &NDRDesc.GlobalOffset[0] : nullptr, &NDRDesc.GlobalSize[0],
2511-
LocalSize, property_list.size(),
2512-
property_list.empty() ? nullptr : property_list.data(), RawEvents.size(),
2513-
RawEvents.empty() ? nullptr : &RawEvents[0],
2514-
OutEventImpl ? &UREvent : nullptr);
2538+
ur_result_t Error =
2539+
Adapter.call_nocheck<UrApiKind::urEnqueueKernelLaunchWithArgsExp>(
2540+
Queue.getHandleRef(), Kernel, NDRDesc.Dims,
2541+
HasOffset ? &NDRDesc.GlobalOffset[0] : nullptr,
2542+
&NDRDesc.GlobalSize[0], LocalSize, UrArgs.size(), UrArgs.data(),
2543+
property_list.size(),
2544+
property_list.empty() ? nullptr : property_list.data(),
2545+
RawEvents.size(), RawEvents.empty() ? nullptr : &RawEvents[0],
2546+
OutEventImpl ? &UREvent : nullptr);
25152547
if (Error == UR_RESULT_SUCCESS && OutEventImpl) {
25162548
OutEventImpl->setHandle(UREvent);
25172549
}
25182550

25192551
return Error;
25202552
}
25212553

2554+
// Sets arguments for a given kernel and device based on the argument type.
2555+
// This is a legacy path which the graphs extension still uses.
2556+
static void SetArgBasedOnType(
2557+
adapter_impl &Adapter, ur_kernel_handle_t Kernel,
2558+
device_image_impl *DeviceImageImpl,
2559+
const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
2560+
context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex) {
2561+
switch (Arg.MType) {
2562+
case kernel_param_kind_t::kind_dynamic_work_group_memory:
2563+
break;
2564+
case kernel_param_kind_t::kind_work_group_memory:
2565+
break;
2566+
case kernel_param_kind_t::kind_stream:
2567+
break;
2568+
case kernel_param_kind_t::kind_dynamic_accessor:
2569+
case kernel_param_kind_t::kind_accessor: {
2570+
Requirement *Req = (Requirement *)(Arg.MPtr);
2571+
2572+
// getMemAllocationFunc is nullptr when there are no requirements. However,
2573+
// we may pass default constructed accessors to a command, which don't add
2574+
// requirements. In such case, getMemAllocationFunc is nullptr, but it's a
2575+
// valid case, so we need to properly handle it.
2576+
ur_mem_handle_t MemArg =
2577+
getMemAllocationFunc
2578+
? reinterpret_cast<ur_mem_handle_t>(getMemAllocationFunc(Req))
2579+
: nullptr;
2580+
ur_kernel_arg_mem_obj_properties_t MemObjData{};
2581+
MemObjData.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2582+
MemObjData.memoryAccess = AccessModeToUr(Req->MAccessMode);
2583+
Adapter.call<UrApiKind::urKernelSetArgMemObj>(Kernel, NextTrueIndex,
2584+
&MemObjData, MemArg);
2585+
break;
2586+
}
2587+
case kernel_param_kind_t::kind_std_layout: {
2588+
if (Arg.MPtr) {
2589+
Adapter.call<UrApiKind::urKernelSetArgValue>(
2590+
Kernel, NextTrueIndex, Arg.MSize, nullptr, Arg.MPtr);
2591+
} else {
2592+
Adapter.call<UrApiKind::urKernelSetArgLocal>(Kernel, NextTrueIndex,
2593+
Arg.MSize, nullptr);
2594+
}
2595+
2596+
break;
2597+
}
2598+
case kernel_param_kind_t::kind_sampler: {
2599+
sampler *SamplerPtr = (sampler *)Arg.MPtr;
2600+
ur_sampler_handle_t Sampler =
2601+
(ur_sampler_handle_t)detail::getSyclObjImpl(*SamplerPtr)
2602+
->getOrCreateSampler(ContextImpl);
2603+
Adapter.call<UrApiKind::urKernelSetArgSampler>(Kernel, NextTrueIndex,
2604+
nullptr, Sampler);
2605+
break;
2606+
}
2607+
case kernel_param_kind_t::kind_pointer: {
2608+
// We need to de-rerence this to get the actual USM allocation - that's the
2609+
// pointer UR is expecting.
2610+
const void *Ptr = *static_cast<const void *const *>(Arg.MPtr);
2611+
Adapter.call<UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2612+
nullptr, Ptr);
2613+
break;
2614+
}
2615+
case kernel_param_kind_t::kind_specialization_constants_buffer: {
2616+
assert(DeviceImageImpl != nullptr);
2617+
ur_mem_handle_t SpecConstsBuffer =
2618+
DeviceImageImpl->get_spec_const_buffer_ref();
2619+
2620+
ur_kernel_arg_mem_obj_properties_t MemObjProps{};
2621+
MemObjProps.pNext = nullptr;
2622+
MemObjProps.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2623+
MemObjProps.memoryAccess = UR_MEM_FLAG_READ_ONLY;
2624+
Adapter.call<UrApiKind::urKernelSetArgMemObj>(
2625+
Kernel, NextTrueIndex, &MemObjProps, SpecConstsBuffer);
2626+
break;
2627+
}
2628+
case kernel_param_kind_t::kind_invalid:
2629+
throw sycl::exception(sycl::make_error_code(sycl::errc::runtime),
2630+
"Invalid kernel param kind " +
2631+
codeToString(UR_RESULT_ERROR_INVALID_VALUE));
2632+
break;
2633+
}
2634+
}
2635+
25222636
static std::tuple<ur_kernel_handle_t, device_image_impl *,
25232637
const KernelArgMask *>
25242638
getCGKernelInfo(const CGExecKernel &CommandGroup, context_impl &ContextImpl,

sycl/test-e2e/Adapters/level_zero/batch_barrier.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ int main(int argc, char *argv[]) {
2424
queue q;
2525

2626
submit_kernel(q); // starts a batch
27-
// CHECK: ---> urEnqueueKernelLaunch
27+
// CHECK: ---> urEnqueueKernelLaunchWithArgsExp
2828
// CHECK-NOT: zeCommandQueueExecuteCommandLists
2929

3030
// Initializing Level Zero driver is required if this test is linked
@@ -42,7 +42,7 @@ int main(int argc, char *argv[]) {
4242
// CHECK-NOT: zeCommandQueueExecuteCommandLists
4343

4444
submit_kernel(q);
45-
// CHECK: ---> urEnqueueKernelLaunch
45+
// CHECK: ---> urEnqueueKernelLaunchWithArgsExp
4646
// CHECK-NOT: zeCommandQueueExecuteCommandLists
4747

4848
// interop should close the batch

0 commit comments

Comments
 (0)