@@ -2315,14 +2315,14 @@ ur_mem_flags_t AccessModeToUr(access::mode AccessorMode) {
2315
2315
}
2316
2316
}
2317
2317
2318
- // Sets arguments for a given kernel and device based on the argument type.
2319
- // Refactored from SetKernelParamsAndLaunch to allow it to be used in the graphs
2320
- // extension.
2321
- static void SetArgBasedOnType (
2322
- adapter_impl &Adapter, ur_kernel_handle_t Kernel,
2318
+ // Gets UR argument struct for a given kernel and device based on the argument
2319
+ // type. Refactored from SetKernelParamsAndLaunch to allow it to be used in
2320
+ // the graphs extension (LaunchWithArgs for graphs is planned future work).
2321
+ static void GetUrArgsBasedOnType (
2323
2322
device_image_impl *DeviceImageImpl,
2324
2323
const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
2325
- context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex) {
2324
+ context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex,
2325
+ std::vector<ur_exp_kernel_arg_properties_t> &UrArgs) {
2326
2326
switch (Arg.MType ) {
2327
2327
case kernel_param_kind_t ::kind_dynamic_work_group_memory:
2328
2328
break ;
@@ -2342,52 +2342,61 @@ static void SetArgBasedOnType(
2342
2342
getMemAllocationFunc
2343
2343
? reinterpret_cast <ur_mem_handle_t >(getMemAllocationFunc (Req))
2344
2344
: nullptr ;
2345
- ur_kernel_arg_mem_obj_properties_t MemObjData{};
2346
- MemObjData.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2347
- MemObjData.memoryAccess = AccessModeToUr (Req->MAccessMode );
2348
- Adapter.call <UrApiKind::urKernelSetArgMemObj>(Kernel, NextTrueIndex,
2349
- &MemObjData, MemArg);
2345
+ ur_exp_kernel_arg_value_t Value = {};
2346
+ Value.memObjTuple = {MemArg, AccessModeToUr (Req->MAccessMode )};
2347
+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2348
+ UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ,
2349
+ static_cast <uint32_t >(NextTrueIndex), sizeof (MemArg),
2350
+ Value});
2350
2351
break ;
2351
2352
}
2352
2353
case kernel_param_kind_t ::kind_std_layout: {
2354
+ ur_exp_kernel_arg_type_t Type;
2353
2355
if (Arg.MPtr ) {
2354
- Adapter.call <UrApiKind::urKernelSetArgValue>(
2355
- Kernel, NextTrueIndex, Arg.MSize , nullptr , Arg.MPtr );
2356
+ Type = UR_EXP_KERNEL_ARG_TYPE_VALUE;
2356
2357
} else {
2357
- Adapter.call <UrApiKind::urKernelSetArgLocal>(Kernel, NextTrueIndex,
2358
- Arg.MSize , nullptr );
2358
+ Type = UR_EXP_KERNEL_ARG_TYPE_LOCAL;
2359
2359
}
2360
+ ur_exp_kernel_arg_value_t Value = {};
2361
+ Value.value = {Arg.MPtr };
2362
+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2363
+ Type, static_cast <uint32_t >(NextTrueIndex),
2364
+ static_cast <size_t >(Arg.MSize ), Value});
2360
2365
2361
2366
break ;
2362
2367
}
2363
2368
case kernel_param_kind_t ::kind_sampler: {
2364
2369
sampler *SamplerPtr = (sampler *)Arg.MPtr ;
2365
- ur_sampler_handle_t Sampler =
2366
- (ur_sampler_handle_t )detail::getSyclObjImpl (*SamplerPtr)
2367
- ->getOrCreateSampler (ContextImpl);
2368
- Adapter.call <UrApiKind::urKernelSetArgSampler>(Kernel, NextTrueIndex,
2369
- nullptr , Sampler);
2370
+ ur_exp_kernel_arg_value_t Value = {};
2371
+ Value.sampler = (ur_sampler_handle_t )detail::getSyclObjImpl (*SamplerPtr)
2372
+ ->getOrCreateSampler (ContextImpl);
2373
+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2374
+ UR_EXP_KERNEL_ARG_TYPE_SAMPLER,
2375
+ static_cast <uint32_t >(NextTrueIndex),
2376
+ sizeof (ur_sampler_handle_t ), Value});
2370
2377
break ;
2371
2378
}
2372
2379
case kernel_param_kind_t ::kind_pointer: {
2373
- // We need to de-rerence this to get the actual USM allocation - that's the
2380
+ ur_exp_kernel_arg_value_t Value = {};
2381
+ // We need to dereference to get the actual USM allocation - that's the
2374
2382
// pointer UR is expecting.
2375
- const void *Ptr = *static_cast <const void *const *>(Arg.MPtr );
2376
- Adapter.call <UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2377
- nullptr , Ptr);
2383
+ Value.pointer = *static_cast <void *const *>(Arg.MPtr );
2384
+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2385
+ UR_EXP_KERNEL_ARG_TYPE_POINTER,
2386
+ static_cast <uint32_t >(NextTrueIndex), sizeof (Arg.MPtr ),
2387
+ Value});
2378
2388
break ;
2379
2389
}
2380
2390
case kernel_param_kind_t ::kind_specialization_constants_buffer: {
2381
2391
assert (DeviceImageImpl != nullptr );
2382
2392
ur_mem_handle_t SpecConstsBuffer =
2383
2393
DeviceImageImpl->get_spec_const_buffer_ref ();
2384
-
2385
- ur_kernel_arg_mem_obj_properties_t MemObjProps{};
2386
- MemObjProps.pNext = nullptr ;
2387
- MemObjProps.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2388
- MemObjProps.memoryAccess = UR_MEM_FLAG_READ_ONLY;
2389
- Adapter.call <UrApiKind::urKernelSetArgMemObj>(
2390
- Kernel, NextTrueIndex, &MemObjProps, SpecConstsBuffer);
2394
+ ur_exp_kernel_arg_value_t Value = {};
2395
+ Value.memObjTuple = {SpecConstsBuffer, UR_MEM_FLAG_READ_ONLY};
2396
+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2397
+ UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ,
2398
+ static_cast <uint32_t >(NextTrueIndex),
2399
+ sizeof (SpecConstsBuffer), Value});
2391
2400
break ;
2392
2401
}
2393
2402
case kernel_param_kind_t ::kind_invalid:
@@ -2420,22 +2429,32 @@ static ur_result_t SetKernelParamsAndLaunch(
2420
2429
DeviceImageImpl ? DeviceImageImpl->get_spec_const_blob_ref () : Empty);
2421
2430
}
2422
2431
2432
+ std::vector<ur_exp_kernel_arg_properties_t > UrArgs;
2433
+ UrArgs.reserve (Args.size ());
2434
+
2423
2435
if (KernelFuncPtr && !KernelHasSpecialCaptures) {
2424
- auto setFunc = [&Adapter, Kernel ,
2436
+ auto setFunc = [&UrArgs ,
2425
2437
KernelFuncPtr](const detail::kernel_param_desc_t &ParamDesc,
2426
2438
size_t NextTrueIndex) {
2427
2439
const void *ArgPtr = (const char *)KernelFuncPtr + ParamDesc.offset ;
2428
2440
switch (ParamDesc.kind ) {
2429
2441
case kernel_param_kind_t ::kind_std_layout: {
2430
2442
int Size = ParamDesc.info ;
2431
- Adapter.call <UrApiKind::urKernelSetArgValue>(Kernel, NextTrueIndex,
2432
- Size, nullptr , ArgPtr);
2443
+ ur_exp_kernel_arg_value_t Value = {};
2444
+ Value.value = ArgPtr;
2445
+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2446
+ UR_EXP_KERNEL_ARG_TYPE_VALUE,
2447
+ static_cast <uint32_t >(NextTrueIndex),
2448
+ static_cast <size_t >(Size), Value});
2433
2449
break ;
2434
2450
}
2435
2451
case kernel_param_kind_t ::kind_pointer: {
2436
- const void *Ptr = *static_cast <const void *const *>(ArgPtr);
2437
- Adapter.call <UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2438
- nullptr , Ptr);
2452
+ ur_exp_kernel_arg_value_t Value = {};
2453
+ Value.pointer = *static_cast <const void *const *>(ArgPtr);
2454
+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2455
+ UR_EXP_KERNEL_ARG_TYPE_POINTER,
2456
+ static_cast <uint32_t >(NextTrueIndex),
2457
+ sizeof (Value.pointer ), Value});
2439
2458
break ;
2440
2459
}
2441
2460
default :
@@ -2445,10 +2464,10 @@ static ur_result_t SetKernelParamsAndLaunch(
2445
2464
applyFuncOnFilteredArgs (EliminatedArgMask, KernelNumArgs,
2446
2465
KernelParamDescGetter, setFunc);
2447
2466
} else {
2448
- auto setFunc = [&Adapter, Kernel, &DeviceImageImpl , &getMemAllocationFunc ,
2449
- &Queue ](detail::ArgDesc &Arg, size_t NextTrueIndex) {
2450
- SetArgBasedOnType (Adapter, Kernel, DeviceImageImpl, getMemAllocationFunc,
2451
- Queue.getContextImpl (), Arg, NextTrueIndex);
2467
+ auto setFunc = [&DeviceImageImpl, &getMemAllocationFunc , &Queue ,
2468
+ &UrArgs ](detail::ArgDesc &Arg, size_t NextTrueIndex) {
2469
+ GetUrArgsBasedOnType ( DeviceImageImpl, getMemAllocationFunc,
2470
+ Queue.getContextImpl (), Arg, NextTrueIndex, UrArgs );
2452
2471
};
2453
2472
applyFuncOnFilteredArgs (EliminatedArgMask, Args, setFunc);
2454
2473
}
@@ -2461,8 +2480,12 @@ static ur_result_t SetKernelParamsAndLaunch(
2461
2480
// CUDA-style local memory setting. Note that we may have -1 as a position,
2462
2481
// this indicates the buffer is actually unused and was elided.
2463
2482
if (ImplicitLocalArg.has_value () && ImplicitLocalArg.value () != -1 ) {
2464
- Adapter.call <UrApiKind::urKernelSetArgLocal>(
2465
- Kernel, ImplicitLocalArg.value (), WorkGroupMemorySize, nullptr );
2483
+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES,
2484
+ nullptr ,
2485
+ UR_EXP_KERNEL_ARG_TYPE_LOCAL,
2486
+ static_cast <uint32_t >(ImplicitLocalArg.value ()),
2487
+ WorkGroupMemorySize,
2488
+ {nullptr }});
2466
2489
}
2467
2490
2468
2491
adjustNDRangePerKernel (NDRDesc, Kernel, Queue.getDeviceImpl ());
@@ -2520,20 +2543,104 @@ static ur_result_t SetKernelParamsAndLaunch(
2520
2543
{{WorkGroupMemorySize}}});
2521
2544
}
2522
2545
ur_event_handle_t UREvent = nullptr ;
2523
- ur_result_t Error = Adapter.call_nocheck <UrApiKind::urEnqueueKernelLaunch>(
2524
- Queue.getHandleRef (), Kernel, NDRDesc.Dims ,
2525
- HasOffset ? &NDRDesc.GlobalOffset [0 ] : nullptr , &NDRDesc.GlobalSize [0 ],
2526
- LocalSize, property_list.size (),
2527
- property_list.empty () ? nullptr : property_list.data (), RawEvents.size (),
2528
- RawEvents.empty () ? nullptr : &RawEvents[0 ],
2529
- OutEventImpl ? &UREvent : nullptr );
2546
+ ur_result_t Error =
2547
+ Adapter.call_nocheck <UrApiKind::urEnqueueKernelLaunchWithArgsExp>(
2548
+ Queue.getHandleRef (), Kernel, NDRDesc.Dims ,
2549
+ HasOffset ? &NDRDesc.GlobalOffset [0 ] : nullptr ,
2550
+ &NDRDesc.GlobalSize [0 ], LocalSize, UrArgs.size (), UrArgs.data (),
2551
+ property_list.size (),
2552
+ property_list.empty () ? nullptr : property_list.data (),
2553
+ RawEvents.size (), RawEvents.empty () ? nullptr : &RawEvents[0 ],
2554
+ OutEventImpl ? &UREvent : nullptr );
2530
2555
if (Error == UR_RESULT_SUCCESS && OutEventImpl) {
2531
2556
OutEventImpl->setHandle (UREvent);
2532
2557
}
2533
2558
2534
2559
return Error;
2535
2560
}
2536
2561
2562
+ // Sets one kernel argument (Arg) on Kernel at index NextTrueIndex, picking the urKernelSetArg* adapter call from Arg.MType; throws sycl::exception for kind_invalid.
2563
+ // This is a legacy path which the graphs extension still uses.
2564
+ static void SetArgBasedOnType (
2565
+ adapter_impl &Adapter, ur_kernel_handle_t Kernel,
2566
+ device_image_impl *DeviceImageImpl,
2567
+ const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
2568
+ context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex) {
2569
+ switch (Arg.MType ) {
2570
+ case kernel_param_kind_t ::kind_dynamic_work_group_memory:
2571
+ break ; // no adapter call is made for this kind in this function
2572
+ case kernel_param_kind_t ::kind_work_group_memory:
2573
+ break ; // no adapter call is made for this kind in this function
2574
+ case kernel_param_kind_t ::kind_stream:
2575
+ break ; // no adapter call is made for this kind in this function
2576
+ case kernel_param_kind_t ::kind_dynamic_accessor:
2577
+ case kernel_param_kind_t ::kind_accessor: {
2578
+ Requirement *Req = (Requirement *)(Arg.MPtr ); // for accessor kinds MPtr is interpreted as a Requirement
2579
+
2580
+ // getMemAllocationFunc is nullptr when there are no requirements. However,
2581
+ // we may pass default constructed accessors to a command, which don't add
2582
+ // requirements. In such case, getMemAllocationFunc is nullptr, but it's a
2583
+ // valid case, so we need to properly handle it.
2584
+ ur_mem_handle_t MemArg =
2585
+ getMemAllocationFunc
2586
+ ? reinterpret_cast <ur_mem_handle_t >(getMemAllocationFunc (Req))
2587
+ : nullptr ;
2588
+ ur_kernel_arg_mem_obj_properties_t MemObjData{};
2589
+ MemObjData.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2590
+ MemObjData.memoryAccess = AccessModeToUr (Req->MAccessMode ); // access flags come from the accessor's mode
2591
+ Adapter.call <UrApiKind::urKernelSetArgMemObj>(Kernel, NextTrueIndex,
2592
+ &MemObjData, MemArg);
2593
+ break ;
2594
+ }
2595
+ case kernel_param_kind_t ::kind_std_layout: {
2596
+ if (Arg.MPtr ) { // non-null MPtr: pass Arg.MSize bytes read from MPtr as a value argument
2597
+ Adapter.call <UrApiKind::urKernelSetArgValue>(
2598
+ Kernel, NextTrueIndex, Arg.MSize , nullptr , Arg.MPtr );
2599
+ } else { // null MPtr: set the argument as a local argument of Arg.MSize bytes
2600
+ Adapter.call <UrApiKind::urKernelSetArgLocal>(Kernel, NextTrueIndex,
2601
+ Arg.MSize , nullptr );
2602
+ }
2603
+
2604
+ break ;
2605
+ }
2606
+ case kernel_param_kind_t ::kind_sampler: {
2607
+ sampler *SamplerPtr = (sampler *)Arg.MPtr ; // for sampler kinds MPtr is interpreted as a sycl sampler
2608
+ ur_sampler_handle_t Sampler =
2609
+ (ur_sampler_handle_t )detail::getSyclObjImpl (*SamplerPtr)
2610
+ ->getOrCreateSampler (ContextImpl);
2611
+ Adapter.call <UrApiKind::urKernelSetArgSampler>(Kernel, NextTrueIndex,
2612
+ nullptr , Sampler);
2613
+ break ;
2614
+ }
2615
+ case kernel_param_kind_t ::kind_pointer: {
2616
+ // We need to dereference this to get the actual USM allocation - that's the
2617
+ // pointer UR is expecting.
2618
+ const void *Ptr = *static_cast <const void *const *>(Arg.MPtr );
2619
+ Adapter.call <UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2620
+ nullptr , Ptr);
2621
+ break ;
2622
+ }
2623
+ case kernel_param_kind_t ::kind_specialization_constants_buffer: {
2624
+ assert (DeviceImageImpl != nullptr ); // the buffer handle is read from the device image below
2625
+ ur_mem_handle_t SpecConstsBuffer =
2626
+ DeviceImageImpl->get_spec_const_buffer_ref ();
2627
+
2628
+ ur_kernel_arg_mem_obj_properties_t MemObjProps{};
2629
+ MemObjProps.pNext = nullptr ;
2630
+ MemObjProps.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2631
+ MemObjProps.memoryAccess = UR_MEM_FLAG_READ_ONLY; // the spec-constants buffer is always bound read-only
2632
+ Adapter.call <UrApiKind::urKernelSetArgMemObj>(
2633
+ Kernel, NextTrueIndex, &MemObjProps, SpecConstsBuffer);
2634
+ break ;
2635
+ }
2636
+ case kernel_param_kind_t ::kind_invalid:
2637
+ throw sycl::exception (sycl::make_error_code (sycl::errc::runtime),
2638
+ " Invalid kernel param kind " +
2639
+ codeToString (UR_RESULT_ERROR_INVALID_VALUE));
2640
+ break ; // not reached: the throw above always leaves the function
2641
+ }
2642
+ }
2643
+
2537
2644
static std::tuple<ur_kernel_handle_t , device_image_impl *,
2538
2645
const KernelArgMask *>
2539
2646
getCGKernelInfo (const CGExecKernel &CommandGroup, context_impl &ContextImpl,
0 commit comments