Skip to content

Commit 6b492be

Browse files
committed
Update interface load with new api.
- Move new function to enqueue.cpp. - Fix impl/tests. - Add device extension string. - Clean up documentation. Signed-off-by: JackAKirk <[email protected]>
1 parent a818d50 commit 6b492be

File tree

6 files changed

+187
-173
lines changed

6 files changed

+187
-173
lines changed

scripts/core/EXP-LAUNCH-ATTRIBUTES.rst

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,17 +35,17 @@ dimension over which the shared memory is accessible. Additionally some
3535
applications require specification of kernel attributes at runtime.
3636

3737
This extension is a future-proof and portable solution that supports these two requirements.
38-
Instead of using a fixed set of kernel arguments, the approach is to introduce
39-
"exp_launch_attribute_t" that enables a more flexible approach.
40-
Each exp_launch_attr_t corresponds to a specific kernel launch attribute.
41-
One new function is introduced. "urEnqueueKernelLaunchCustomExp" takes an
42-
array of launch_attribute_t as an argument, and launches a kernel using these
43-
attributes. "urEnqueueKernelLaunchCustomExp" corresponds to the CUDA Driver API
44-
"cuLaunchKernelEx".
45-
46-
Many kernel properties could be supported, such as cooperative kernels. As such,
38+
Instead of using a fixed set of kernel arguments, the approach is to introduce the
39+
`exp_launch_attribute_t` type that enables a more flexible approach.
40+
Each `exp_launch_attribute_t` instance corresponds to a specific kernel launch attribute.
41+
One new function is introduced; `urEnqueueKernelLaunchCustomExp` takes an
42+
array of `exp_launch_attribute_t` as an argument, and launches a kernel using these
43+
attributes. `urEnqueueKernelLaunchCustomExp` corresponds closely to the CUDA Driver API
44+
`cuLaunchKernelEx`.
45+
46+
Many kernel properties can be supported, such as cooperative kernels. As such,
4747
eventually this extension should be able to replace the cooperative kernels
48-
UR extension.
48+
UR extension.
4949

5050
API
5151
--------------------------------------------------------------------------------

source/adapters/cuda/device.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -628,6 +628,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
628628
// Return supported for the UR command-buffer experimental feature
629629
SupportedExtensions += "ur_exp_command_buffer ";
630630
SupportedExtensions += "ur_exp_usm_p2p ";
631+
SupportedExtensions += "ur_exp_launch_attributes ";
631632
SupportedExtensions += " ";
632633

633634
int Major = 0;

source/adapters/cuda/enqueue.cpp

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,133 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
484484
numEventsInWaitList, phEventWaitList, phEvent);
485485
}
486486

487+
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
488+
ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
489+
const size_t *pGlobalWorkSize, const size_t *pLocalWorkSize,
490+
uint32_t numAttrsInLaunchAttrList,
491+
const ur_exp_launch_attribute_t *launchAttrList,
492+
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
493+
ur_event_handle_t *phEvent) {
494+
495+
if (numAttrsInLaunchAttrList == 0) {
496+
urEnqueueKernelLaunch(hQueue, hKernel, workDim, nullptr, pGlobalWorkSize,
497+
pLocalWorkSize, numEventsInWaitList, phEventWaitList,
498+
phEvent);
499+
}
500+
501+
// Preconditions
502+
UR_ASSERT(hQueue->getContext() == hKernel->getContext(),
503+
UR_RESULT_ERROR_INVALID_KERNEL);
504+
UR_ASSERT(workDim > 0, UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
505+
UR_ASSERT(workDim < 4, UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
506+
507+
if (launchAttrList == NULL) {
508+
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
509+
}
510+
511+
std::vector<CUlaunchAttribute> launch_attribute(numAttrsInLaunchAttrList);
512+
for (uint32_t i = 0; i < numAttrsInLaunchAttrList; i++) {
513+
switch (launchAttrList[i].id) {
514+
case UR_EXP_LAUNCH_ATTRIBUTE_ID_IGNORE: {
515+
launch_attribute[i].id = CU_LAUNCH_ATTRIBUTE_IGNORE;
516+
break;
517+
}
518+
case UR_EXP_LAUNCH_ATTRIBUTE_ID_CLUSTER_DIMENSION: {
519+
520+
launch_attribute[i].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
521+
launch_attribute[i].value.clusterDim.x =
522+
launchAttrList[i].value.clusterDim[0];
523+
launch_attribute[i].value.clusterDim.y =
524+
launchAttrList[i].value.clusterDim[1];
525+
launch_attribute[i].value.clusterDim.z =
526+
launchAttrList[i].value.clusterDim[2];
527+
break;
528+
}
529+
case UR_EXP_LAUNCH_ATTRIBUTE_ID_COOPERATIVE: {
530+
launch_attribute[i].id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE;
531+
launch_attribute[i].value.cooperative =
532+
launchAttrList[i].value.cooperative;
533+
break;
534+
}
535+
default: {
536+
return UR_RESULT_ERROR_INVALID_ENUMERATION;
537+
}
538+
}
539+
}
540+
541+
if (*pGlobalWorkSize == 0) {
542+
return urEnqueueEventsWaitWithBarrier(hQueue, numEventsInWaitList,
543+
phEventWaitList, phEvent);
544+
}
545+
546+
// Set the number of threads per block to the number of threads per warp
547+
// by default unless user has provided a better number
548+
size_t ThreadsPerBlock[3] = {32u, 1u, 1u};
549+
size_t BlocksPerGrid[3] = {1u, 1u, 1u};
550+
551+
uint32_t LocalSize = hKernel->getLocalSize();
552+
ur_result_t Result = UR_RESULT_SUCCESS;
553+
CUfunction CuFunc = hKernel->get();
554+
555+
Result = setKernelParams(hQueue->getContext(), hQueue->Device, workDim,
556+
nullptr, pGlobalWorkSize, pLocalWorkSize, hKernel,
557+
CuFunc, ThreadsPerBlock, BlocksPerGrid);
558+
if (Result != UR_RESULT_SUCCESS) {
559+
return Result;
560+
}
561+
562+
try {
563+
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
564+
565+
uint32_t StreamToken;
566+
ur_stream_guard_ Guard;
567+
CUstream CuStream = hQueue->getNextComputeStream(
568+
numEventsInWaitList, phEventWaitList, Guard, &StreamToken);
569+
570+
Result = enqueueEventsWait(hQueue, CuStream, numEventsInWaitList,
571+
phEventWaitList);
572+
573+
if (phEvent) {
574+
RetImplEvent =
575+
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
576+
UR_COMMAND_KERNEL_LAUNCH, hQueue, CuStream, StreamToken));
577+
UR_CHECK_ERROR(RetImplEvent->start());
578+
}
579+
580+
auto &ArgIndices = hKernel->getArgIndices();
581+
582+
CUlaunchConfig launch_config;
583+
launch_config.gridDimX = BlocksPerGrid[0];
584+
launch_config.gridDimY = BlocksPerGrid[1];
585+
launch_config.gridDimZ = BlocksPerGrid[2];
586+
launch_config.blockDimX = ThreadsPerBlock[0];
587+
launch_config.blockDimY = ThreadsPerBlock[1];
588+
launch_config.blockDimZ = ThreadsPerBlock[2];
589+
590+
launch_config.sharedMemBytes = LocalSize;
591+
launch_config.hStream = CuStream;
592+
launch_config.attrs = &launch_attribute[0];
593+
launch_config.numAttrs = numAttrsInLaunchAttrList;
594+
595+
UR_CHECK_ERROR(cuLaunchKernelEx(&launch_config, CuFunc,
596+
const_cast<void **>(ArgIndices.data()),
597+
nullptr));
598+
599+
if (LocalSize != 0)
600+
hKernel->clearLocalSize();
601+
602+
if (phEvent) {
603+
UR_CHECK_ERROR(RetImplEvent->record());
604+
*phEvent = RetImplEvent.release();
605+
}
606+
607+
} catch (ur_result_t Err) {
608+
Result = Err;
609+
}
610+
return Result;
611+
}
612+
613+
487614
/// Set parameters for general 3D memory copy.
488615
/// If the source and/or destination is on the device, SrcPtr and/or DstPtr
489616
/// must be a pointer to a CUdeviceptr

source/adapters/cuda/launch_attributes.cpp

Lines changed: 0 additions & 133 deletions
This file was deleted.

source/adapters/cuda/ur_interface_loader.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueExpProcAddrTable(
408408
pDdiTable->pfnCooperativeKernelLaunchExp =
409409
urEnqueueCooperativeKernelLaunchExp;
410410
pDdiTable->pfnTimestampRecordingExp = urEnqueueTimestampRecordingExp;
411+
pDdiTable->pfnKernelLaunchCustomExp = urEnqueueKernelLaunchCustomExp;
411412

412413
return UR_RESULT_SUCCESS;
413414
}

0 commit comments

Comments
 (0)