Skip to content

Commit 24f54ea

Browse files
authored
[UR][SYCL] Introduce UR api to set kernel args + launch in one call. (#18764)
Introduces `urEnqueueKernelLaunchWithArgsExp`, a new UR entry point which allows combining `KernelSetArg` and `EnqueueKernelLaunch` calls into one. Replaces `urEnqueueKernelLaunch` as the default kernel launch path in sycl RT.
1 parent d6928aa commit 24f54ea

File tree

85 files changed

+3356
-422
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

85 files changed

+3356
-422
lines changed

sycl/source/detail/scheduler/commands.cpp

Lines changed: 157 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -2315,14 +2315,14 @@ ur_mem_flags_t AccessModeToUr(access::mode AccessorMode) {
23152315
}
23162316
}
23172317

2318-
// Sets arguments for a given kernel and device based on the argument type.
2319-
// Refactored from SetKernelParamsAndLaunch to allow it to be used in the graphs
2320-
// extension.
2321-
static void SetArgBasedOnType(
2322-
adapter_impl &Adapter, ur_kernel_handle_t Kernel,
2318+
// Gets UR argument struct for a given kernel and device based on the argument
2319+
// type. Refactored from SetKernelParamsAndLaunch to allow it to be used in
2320+
// the graphs extension (LaunchWithArgs for graphs is planned future work).
2321+
static void GetUrArgsBasedOnType(
23232322
device_image_impl *DeviceImageImpl,
23242323
const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
2325-
context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex) {
2324+
context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex,
2325+
std::vector<ur_exp_kernel_arg_properties_t> &UrArgs) {
23262326
switch (Arg.MType) {
23272327
case kernel_param_kind_t::kind_dynamic_work_group_memory:
23282328
break;
@@ -2342,52 +2342,61 @@ static void SetArgBasedOnType(
23422342
getMemAllocationFunc
23432343
? reinterpret_cast<ur_mem_handle_t>(getMemAllocationFunc(Req))
23442344
: nullptr;
2345-
ur_kernel_arg_mem_obj_properties_t MemObjData{};
2346-
MemObjData.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2347-
MemObjData.memoryAccess = AccessModeToUr(Req->MAccessMode);
2348-
Adapter.call<UrApiKind::urKernelSetArgMemObj>(Kernel, NextTrueIndex,
2349-
&MemObjData, MemArg);
2345+
ur_exp_kernel_arg_value_t Value = {};
2346+
Value.memObjTuple = {MemArg, AccessModeToUr(Req->MAccessMode)};
2347+
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
2348+
UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ,
2349+
static_cast<uint32_t>(NextTrueIndex), sizeof(MemArg),
2350+
Value});
23502351
break;
23512352
}
23522353
case kernel_param_kind_t::kind_std_layout: {
2354+
ur_exp_kernel_arg_type_t Type;
23532355
if (Arg.MPtr) {
2354-
Adapter.call<UrApiKind::urKernelSetArgValue>(
2355-
Kernel, NextTrueIndex, Arg.MSize, nullptr, Arg.MPtr);
2356+
Type = UR_EXP_KERNEL_ARG_TYPE_VALUE;
23562357
} else {
2357-
Adapter.call<UrApiKind::urKernelSetArgLocal>(Kernel, NextTrueIndex,
2358-
Arg.MSize, nullptr);
2358+
Type = UR_EXP_KERNEL_ARG_TYPE_LOCAL;
23592359
}
2360+
ur_exp_kernel_arg_value_t Value = {};
2361+
Value.value = {Arg.MPtr};
2362+
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
2363+
Type, static_cast<uint32_t>(NextTrueIndex),
2364+
static_cast<size_t>(Arg.MSize), Value});
23602365

23612366
break;
23622367
}
23632368
case kernel_param_kind_t::kind_sampler: {
23642369
sampler *SamplerPtr = (sampler *)Arg.MPtr;
2365-
ur_sampler_handle_t Sampler =
2366-
(ur_sampler_handle_t)detail::getSyclObjImpl(*SamplerPtr)
2367-
->getOrCreateSampler(ContextImpl);
2368-
Adapter.call<UrApiKind::urKernelSetArgSampler>(Kernel, NextTrueIndex,
2369-
nullptr, Sampler);
2370+
ur_exp_kernel_arg_value_t Value = {};
2371+
Value.sampler = (ur_sampler_handle_t)detail::getSyclObjImpl(*SamplerPtr)
2372+
->getOrCreateSampler(ContextImpl);
2373+
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
2374+
UR_EXP_KERNEL_ARG_TYPE_SAMPLER,
2375+
static_cast<uint32_t>(NextTrueIndex),
2376+
sizeof(ur_sampler_handle_t), Value});
23702377
break;
23712378
}
23722379
case kernel_param_kind_t::kind_pointer: {
2373-
// We need to de-rerence this to get the actual USM allocation - that's the
2380+
ur_exp_kernel_arg_value_t Value = {};
2381+
// We need to de-rerence to get the actual USM allocation - that's the
23742382
// pointer UR is expecting.
2375-
const void *Ptr = *static_cast<const void *const *>(Arg.MPtr);
2376-
Adapter.call<UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2377-
nullptr, Ptr);
2383+
Value.pointer = *static_cast<void *const *>(Arg.MPtr);
2384+
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
2385+
UR_EXP_KERNEL_ARG_TYPE_POINTER,
2386+
static_cast<uint32_t>(NextTrueIndex), sizeof(Arg.MPtr),
2387+
Value});
23782388
break;
23792389
}
23802390
case kernel_param_kind_t::kind_specialization_constants_buffer: {
23812391
assert(DeviceImageImpl != nullptr);
23822392
ur_mem_handle_t SpecConstsBuffer =
23832393
DeviceImageImpl->get_spec_const_buffer_ref();
2384-
2385-
ur_kernel_arg_mem_obj_properties_t MemObjProps{};
2386-
MemObjProps.pNext = nullptr;
2387-
MemObjProps.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2388-
MemObjProps.memoryAccess = UR_MEM_FLAG_READ_ONLY;
2389-
Adapter.call<UrApiKind::urKernelSetArgMemObj>(
2390-
Kernel, NextTrueIndex, &MemObjProps, SpecConstsBuffer);
2394+
ur_exp_kernel_arg_value_t Value = {};
2395+
Value.memObjTuple = {SpecConstsBuffer, UR_MEM_FLAG_READ_ONLY};
2396+
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
2397+
UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ,
2398+
static_cast<uint32_t>(NextTrueIndex),
2399+
sizeof(SpecConstsBuffer), Value});
23912400
break;
23922401
}
23932402
case kernel_param_kind_t::kind_invalid:
@@ -2420,22 +2429,32 @@ static ur_result_t SetKernelParamsAndLaunch(
24202429
DeviceImageImpl ? DeviceImageImpl->get_spec_const_blob_ref() : Empty);
24212430
}
24222431

2432+
std::vector<ur_exp_kernel_arg_properties_t> UrArgs;
2433+
UrArgs.reserve(Args.size());
2434+
24232435
if (KernelFuncPtr && !KernelHasSpecialCaptures) {
2424-
auto setFunc = [&Adapter, Kernel,
2436+
auto setFunc = [&UrArgs,
24252437
KernelFuncPtr](const detail::kernel_param_desc_t &ParamDesc,
24262438
size_t NextTrueIndex) {
24272439
const void *ArgPtr = (const char *)KernelFuncPtr + ParamDesc.offset;
24282440
switch (ParamDesc.kind) {
24292441
case kernel_param_kind_t::kind_std_layout: {
24302442
int Size = ParamDesc.info;
2431-
Adapter.call<UrApiKind::urKernelSetArgValue>(Kernel, NextTrueIndex,
2432-
Size, nullptr, ArgPtr);
2443+
ur_exp_kernel_arg_value_t Value = {};
2444+
Value.value = ArgPtr;
2445+
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
2446+
UR_EXP_KERNEL_ARG_TYPE_VALUE,
2447+
static_cast<uint32_t>(NextTrueIndex),
2448+
static_cast<size_t>(Size), Value});
24332449
break;
24342450
}
24352451
case kernel_param_kind_t::kind_pointer: {
2436-
const void *Ptr = *static_cast<const void *const *>(ArgPtr);
2437-
Adapter.call<UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2438-
nullptr, Ptr);
2452+
ur_exp_kernel_arg_value_t Value = {};
2453+
Value.pointer = *static_cast<const void *const *>(ArgPtr);
2454+
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
2455+
UR_EXP_KERNEL_ARG_TYPE_POINTER,
2456+
static_cast<uint32_t>(NextTrueIndex),
2457+
sizeof(Value.pointer), Value});
24392458
break;
24402459
}
24412460
default:
@@ -2445,10 +2464,10 @@ static ur_result_t SetKernelParamsAndLaunch(
24452464
applyFuncOnFilteredArgs(EliminatedArgMask, KernelNumArgs,
24462465
KernelParamDescGetter, setFunc);
24472466
} else {
2448-
auto setFunc = [&Adapter, Kernel, &DeviceImageImpl, &getMemAllocationFunc,
2449-
&Queue](detail::ArgDesc &Arg, size_t NextTrueIndex) {
2450-
SetArgBasedOnType(Adapter, Kernel, DeviceImageImpl, getMemAllocationFunc,
2451-
Queue.getContextImpl(), Arg, NextTrueIndex);
2467+
auto setFunc = [&DeviceImageImpl, &getMemAllocationFunc, &Queue,
2468+
&UrArgs](detail::ArgDesc &Arg, size_t NextTrueIndex) {
2469+
GetUrArgsBasedOnType(DeviceImageImpl, getMemAllocationFunc,
2470+
Queue.getContextImpl(), Arg, NextTrueIndex, UrArgs);
24522471
};
24532472
applyFuncOnFilteredArgs(EliminatedArgMask, Args, setFunc);
24542473
}
@@ -2461,8 +2480,12 @@ static ur_result_t SetKernelParamsAndLaunch(
24612480
// CUDA-style local memory setting. Note that we may have -1 as a position,
24622481
// this indicates the buffer is actually unused and was elided.
24632482
if (ImplicitLocalArg.has_value() && ImplicitLocalArg.value() != -1) {
2464-
Adapter.call<UrApiKind::urKernelSetArgLocal>(
2465-
Kernel, ImplicitLocalArg.value(), WorkGroupMemorySize, nullptr);
2483+
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES,
2484+
nullptr,
2485+
UR_EXP_KERNEL_ARG_TYPE_LOCAL,
2486+
static_cast<uint32_t>(ImplicitLocalArg.value()),
2487+
WorkGroupMemorySize,
2488+
{nullptr}});
24662489
}
24672490

24682491
adjustNDRangePerKernel(NDRDesc, Kernel, Queue.getDeviceImpl());
@@ -2520,20 +2543,104 @@ static ur_result_t SetKernelParamsAndLaunch(
25202543
{{WorkGroupMemorySize}}});
25212544
}
25222545
ur_event_handle_t UREvent = nullptr;
2523-
ur_result_t Error = Adapter.call_nocheck<UrApiKind::urEnqueueKernelLaunch>(
2524-
Queue.getHandleRef(), Kernel, NDRDesc.Dims,
2525-
HasOffset ? &NDRDesc.GlobalOffset[0] : nullptr, &NDRDesc.GlobalSize[0],
2526-
LocalSize, property_list.size(),
2527-
property_list.empty() ? nullptr : property_list.data(), RawEvents.size(),
2528-
RawEvents.empty() ? nullptr : &RawEvents[0],
2529-
OutEventImpl ? &UREvent : nullptr);
2546+
ur_result_t Error =
2547+
Adapter.call_nocheck<UrApiKind::urEnqueueKernelLaunchWithArgsExp>(
2548+
Queue.getHandleRef(), Kernel, NDRDesc.Dims,
2549+
HasOffset ? &NDRDesc.GlobalOffset[0] : nullptr,
2550+
&NDRDesc.GlobalSize[0], LocalSize, UrArgs.size(), UrArgs.data(),
2551+
property_list.size(),
2552+
property_list.empty() ? nullptr : property_list.data(),
2553+
RawEvents.size(), RawEvents.empty() ? nullptr : &RawEvents[0],
2554+
OutEventImpl ? &UREvent : nullptr);
25302555
if (Error == UR_RESULT_SUCCESS && OutEventImpl) {
25312556
OutEventImpl->setHandle(UREvent);
25322557
}
25332558

25342559
return Error;
25352560
}
25362561

2562+
// Sets arguments for a given kernel and device based on the argument type.
2563+
// This is a legacy path which the graphs extension still uses.
2564+
static void SetArgBasedOnType(
2565+
adapter_impl &Adapter, ur_kernel_handle_t Kernel,
2566+
device_image_impl *DeviceImageImpl,
2567+
const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
2568+
context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex) {
2569+
switch (Arg.MType) {
2570+
case kernel_param_kind_t::kind_dynamic_work_group_memory:
2571+
break;
2572+
case kernel_param_kind_t::kind_work_group_memory:
2573+
break;
2574+
case kernel_param_kind_t::kind_stream:
2575+
break;
2576+
case kernel_param_kind_t::kind_dynamic_accessor:
2577+
case kernel_param_kind_t::kind_accessor: {
2578+
Requirement *Req = (Requirement *)(Arg.MPtr);
2579+
2580+
// getMemAllocationFunc is nullptr when there are no requirements. However,
2581+
// we may pass default constructed accessors to a command, which don't add
2582+
// requirements. In such case, getMemAllocationFunc is nullptr, but it's a
2583+
// valid case, so we need to properly handle it.
2584+
ur_mem_handle_t MemArg =
2585+
getMemAllocationFunc
2586+
? reinterpret_cast<ur_mem_handle_t>(getMemAllocationFunc(Req))
2587+
: nullptr;
2588+
ur_kernel_arg_mem_obj_properties_t MemObjData{};
2589+
MemObjData.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2590+
MemObjData.memoryAccess = AccessModeToUr(Req->MAccessMode);
2591+
Adapter.call<UrApiKind::urKernelSetArgMemObj>(Kernel, NextTrueIndex,
2592+
&MemObjData, MemArg);
2593+
break;
2594+
}
2595+
case kernel_param_kind_t::kind_std_layout: {
2596+
if (Arg.MPtr) {
2597+
Adapter.call<UrApiKind::urKernelSetArgValue>(
2598+
Kernel, NextTrueIndex, Arg.MSize, nullptr, Arg.MPtr);
2599+
} else {
2600+
Adapter.call<UrApiKind::urKernelSetArgLocal>(Kernel, NextTrueIndex,
2601+
Arg.MSize, nullptr);
2602+
}
2603+
2604+
break;
2605+
}
2606+
case kernel_param_kind_t::kind_sampler: {
2607+
sampler *SamplerPtr = (sampler *)Arg.MPtr;
2608+
ur_sampler_handle_t Sampler =
2609+
(ur_sampler_handle_t)detail::getSyclObjImpl(*SamplerPtr)
2610+
->getOrCreateSampler(ContextImpl);
2611+
Adapter.call<UrApiKind::urKernelSetArgSampler>(Kernel, NextTrueIndex,
2612+
nullptr, Sampler);
2613+
break;
2614+
}
2615+
case kernel_param_kind_t::kind_pointer: {
2616+
// We need to de-rerence this to get the actual USM allocation - that's the
2617+
// pointer UR is expecting.
2618+
const void *Ptr = *static_cast<const void *const *>(Arg.MPtr);
2619+
Adapter.call<UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2620+
nullptr, Ptr);
2621+
break;
2622+
}
2623+
case kernel_param_kind_t::kind_specialization_constants_buffer: {
2624+
assert(DeviceImageImpl != nullptr);
2625+
ur_mem_handle_t SpecConstsBuffer =
2626+
DeviceImageImpl->get_spec_const_buffer_ref();
2627+
2628+
ur_kernel_arg_mem_obj_properties_t MemObjProps{};
2629+
MemObjProps.pNext = nullptr;
2630+
MemObjProps.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2631+
MemObjProps.memoryAccess = UR_MEM_FLAG_READ_ONLY;
2632+
Adapter.call<UrApiKind::urKernelSetArgMemObj>(
2633+
Kernel, NextTrueIndex, &MemObjProps, SpecConstsBuffer);
2634+
break;
2635+
}
2636+
case kernel_param_kind_t::kind_invalid:
2637+
throw sycl::exception(sycl::make_error_code(sycl::errc::runtime),
2638+
"Invalid kernel param kind " +
2639+
codeToString(UR_RESULT_ERROR_INVALID_VALUE));
2640+
break;
2641+
}
2642+
}
2643+
25372644
static std::tuple<ur_kernel_handle_t, device_image_impl *,
25382645
const KernelArgMask *>
25392646
getCGKernelInfo(const CGExecKernel &CommandGroup, context_impl &ContextImpl,

sycl/test-e2e/Adapters/level_zero/batch_barrier.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ int main(int argc, char *argv[]) {
2424
queue q;
2525

2626
submit_kernel(q); // starts a batch
27-
// CHECK: ---> urEnqueueKernelLaunch
27+
// CHECK: ---> urEnqueueKernelLaunchWithArgsExp
2828
// CHECK-NOT: zeCommandQueueExecuteCommandLists
2929

3030
// Initialize Level Zero driver is required if this test is linked
@@ -41,7 +41,7 @@ int main(int argc, char *argv[]) {
4141
// CHECK-NOT: zeCommandQueueExecuteCommandLists
4242

4343
submit_kernel(q);
44-
// CHECK: ---> urEnqueueKernelLaunch
44+
// CHECK: ---> urEnqueueKernelLaunchWithArgsExp
4545
// CHECK-NOT: zeCommandQueueExecuteCommandLists
4646

4747
// interop should close the batch

0 commit comments

Comments
 (0)