@@ -422,11 +422,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueEventsWait(
422422 phEventWaitList, phEvent);
423423}
424424
425- UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch (
426- ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
427- const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
428- const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
429- const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
425+ static ur_result_t
426+ enqueueKernelLaunch (ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel,
427+ uint32_t workDim, const size_t *pGlobalWorkOffset,
428+ const size_t *pGlobalWorkSize, const size_t *pLocalWorkSize,
429+ uint32_t numEventsInWaitList,
430+ const ur_event_handle_t *phEventWaitList,
431+ ur_event_handle_t *phEvent, size_t WorkGroupMemory) {
430432 // Preconditions
431433 UR_ASSERT (hQueue->getDevice () == hKernel->getProgram ()->getDevice (),
432434 UR_RESULT_ERROR_INVALID_KERNEL);
@@ -444,6 +446,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
444446 size_t ThreadsPerBlock[3 ] = {32u , 1u , 1u };
445447 size_t BlocksPerGrid[3 ] = {1u , 1u , 1u };
446448
449+ // Set work group memory so we can compute the whole memory requirement
450+ if (WorkGroupMemory)
451+ hKernel->setWorkGroupMemory (WorkGroupMemory);
447452 uint32_t LocalSize = hKernel->getLocalSize ();
448453 CUfunction CuFunc = hKernel->get ();
449454
@@ -506,6 +511,17 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
506511 return UR_RESULT_SUCCESS;
507512}
508513
// Public entry point for a plain kernel launch.
//
// Forwards every argument unchanged to the file-local enqueueKernelLaunch()
// helper and supplies 0 for its trailing WorkGroupMemory parameter; the
// helper's result is returned as-is.
// NOTE(review): the helper appears to skip hKernel->setWorkGroupMemory() when
// WorkGroupMemory is 0, so this path presumably leaves the kernel's existing
// local-memory configuration untouched; confirm against the full helper body.
514+ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch (
515+ ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
516+ const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
517+ const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
518+ const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
// Delegate to the shared implementation; no extra work-group memory requested.
519+ return enqueueKernelLaunch (hQueue, hKernel, workDim, pGlobalWorkOffset,
520+ pGlobalWorkSize, pLocalWorkSize,
521+ numEventsInWaitList, phEventWaitList, phEvent,
522+ /* WorkGroupMemory=*/ 0 );
523+ }
524+
509525UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp (
510526 ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
511527 const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
@@ -516,8 +532,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
516532 coop_prop.id = UR_EXP_LAUNCH_PROPERTY_ID_COOPERATIVE;
517533 coop_prop.value .cooperative = 1 ;
518534 return urEnqueueKernelLaunchCustomExp (
519- hQueue, hKernel, workDim, pGlobalWorkSize, pLocalWorkSize, 1 ,
520- &coop_prop, numEventsInWaitList, phEventWaitList, phEvent);
535+ hQueue, hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize,
536+ pLocalWorkSize, 1 , &coop_prop, numEventsInWaitList, phEventWaitList,
537+ phEvent);
521538 }
522539 return urEnqueueKernelLaunch (hQueue, hKernel, workDim, pGlobalWorkOffset,
523540 pGlobalWorkSize, pLocalWorkSize,
@@ -526,16 +543,29 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
526543
527544UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp (
528545 ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
529- const size_t *pGlobalWorkSize , const size_t *pLocalWorkSize ,
530- uint32_t numPropsInLaunchPropList,
546+ const size_t *pGlobalWorkOffset , const size_t *pGlobalWorkSize ,
547+ const size_t *pLocalWorkSize, uint32_t numPropsInLaunchPropList,
531548 const ur_exp_launch_property_t *launchPropList,
532549 uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
533550 ur_event_handle_t *phEvent) {
534551
535- if (numPropsInLaunchPropList == 0 ) {
536- urEnqueueKernelLaunch (hQueue, hKernel, workDim, nullptr , pGlobalWorkSize,
537- pLocalWorkSize, numEventsInWaitList, phEventWaitList,
538- phEvent);
552+ size_t WorkGroupMemory = [&]() -> size_t {
553+ const ur_exp_launch_property_t *WorkGroupMemoryProp = std::find_if (
554+ launchPropList, launchPropList + numPropsInLaunchPropList,
555+ [](const ur_exp_launch_property_t &Prop) {
556+ return Prop.id == UR_EXP_LAUNCH_PROPERTY_ID_WORK_GROUP_MEMORY;
557+ });
558+ if (WorkGroupMemoryProp != launchPropList + numPropsInLaunchPropList)
559+ return WorkGroupMemoryProp->value .workgroup_mem_size ;
560+ return 0 ;
561+ }();
562+
563+ if (numPropsInLaunchPropList == 0 ||
564+ (WorkGroupMemory && numPropsInLaunchPropList == 1 )) {
565+ return enqueueKernelLaunch (hQueue, hKernel, workDim, pGlobalWorkOffset,
566+ pGlobalWorkSize, pLocalWorkSize,
567+ numEventsInWaitList, phEventWaitList, phEvent,
568+ WorkGroupMemory);
539569 }
540570#if CUDA_VERSION >= 11080
541571 // Preconditions
@@ -548,7 +578,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
548578 return UR_RESULT_ERROR_INVALID_NULL_POINTER;
549579 }
550580
551- std::vector<CUlaunchAttribute> launch_attribute (numPropsInLaunchPropList);
581+ std::vector<CUlaunchAttribute> launch_attribute;
582+ launch_attribute.reserve (numPropsInLaunchPropList);
552583
553584 // Early exit for zero size kernel
554585 if (*pGlobalWorkSize == 0 ) {
@@ -561,40 +592,35 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
561592 size_t ThreadsPerBlock[3 ] = {32u , 1u , 1u };
562593 size_t BlocksPerGrid[3 ] = {1u , 1u , 1u };
563594
595+ // Set work group memory so we can compute the whole memory requirement
596+ if (WorkGroupMemory)
597+ hKernel->setWorkGroupMemory (WorkGroupMemory);
564598 uint32_t LocalSize = hKernel->getLocalSize ();
565599 CUfunction CuFunc = hKernel->get ();
566600
567601 for (uint32_t i = 0 ; i < numPropsInLaunchPropList; i++) {
568602 switch (launchPropList[i].id ) {
569603 case UR_EXP_LAUNCH_PROPERTY_ID_IGNORE: {
570- launch_attribute[i].id = CU_LAUNCH_ATTRIBUTE_IGNORE;
604+ auto &attr = launch_attribute.emplace_back ();
605+ attr.id = CU_LAUNCH_ATTRIBUTE_IGNORE;
571606 break ;
572607 }
573608 case UR_EXP_LAUNCH_PROPERTY_ID_CLUSTER_DIMENSION: {
574-
575- launch_attribute[i] .id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
609+ auto &attr = launch_attribute. emplace_back ();
610+ attr .id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
576611 // Note that cuda orders from right to left wrt SYCL dimensional order.
577612 if (workDim == 3 ) {
578- launch_attribute[i].value .clusterDim .x =
579- launchPropList[i].value .clusterDim [2 ];
580- launch_attribute[i].value .clusterDim .y =
581- launchPropList[i].value .clusterDim [1 ];
582- launch_attribute[i].value .clusterDim .z =
583- launchPropList[i].value .clusterDim [0 ];
613+ attr.value .clusterDim .x = launchPropList[i].value .clusterDim [2 ];
614+ attr.value .clusterDim .y = launchPropList[i].value .clusterDim [1 ];
615+ attr.value .clusterDim .z = launchPropList[i].value .clusterDim [0 ];
584616 } else if (workDim == 2 ) {
585- launch_attribute[i].value .clusterDim .x =
586- launchPropList[i].value .clusterDim [1 ];
587- launch_attribute[i].value .clusterDim .y =
588- launchPropList[i].value .clusterDim [0 ];
589- launch_attribute[i].value .clusterDim .z =
590- launchPropList[i].value .clusterDim [2 ];
617+ attr.value .clusterDim .x = launchPropList[i].value .clusterDim [1 ];
618+ attr.value .clusterDim .y = launchPropList[i].value .clusterDim [0 ];
619+ attr.value .clusterDim .z = launchPropList[i].value .clusterDim [2 ];
591620 } else {
592- launch_attribute[i].value .clusterDim .x =
593- launchPropList[i].value .clusterDim [0 ];
594- launch_attribute[i].value .clusterDim .y =
595- launchPropList[i].value .clusterDim [1 ];
596- launch_attribute[i].value .clusterDim .z =
597- launchPropList[i].value .clusterDim [2 ];
621+ attr.value .clusterDim .x = launchPropList[i].value .clusterDim [0 ];
622+ attr.value .clusterDim .y = launchPropList[i].value .clusterDim [1 ];
623+ attr.value .clusterDim .z = launchPropList[i].value .clusterDim [2 ];
598624 }
599625
600626 UR_CHECK_ERROR (cuFuncSetAttribute (
@@ -603,9 +629,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
603629 break ;
604630 }
605631 case UR_EXP_LAUNCH_PROPERTY_ID_COOPERATIVE: {
606- launch_attribute[i].id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE;
607- launch_attribute[i].value .cooperative =
608- launchPropList[i].value .cooperative ;
632+ auto &attr = launch_attribute.emplace_back ();
633+ attr.id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE;
634+ attr.value .cooperative = launchPropList[i].value .cooperative ;
635+ break ;
636+ }
637+ case UR_EXP_LAUNCH_PROPERTY_ID_WORK_GROUP_MEMORY: {
609638 break ;
610639 }
611640 default : {
@@ -618,8 +647,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
618647 // using the standard UR_CHECK_ERROR
619648 if (ur_result_t Ret =
620649 setKernelParams (hQueue->getContext (), hQueue->Device , workDim,
621- nullptr , pGlobalWorkSize, pLocalWorkSize, hKernel ,
622- CuFunc, ThreadsPerBlock, BlocksPerGrid);
650+ pGlobalWorkOffset , pGlobalWorkSize, pLocalWorkSize,
651+ hKernel, CuFunc, ThreadsPerBlock, BlocksPerGrid);
623652 Ret != UR_RESULT_SUCCESS)
624653 return Ret;
625654
@@ -667,7 +696,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
667696 launch_config.sharedMemBytes = LocalSize;
668697 launch_config.hStream = CuStream;
669698 launch_config.attrs = &launch_attribute[0 ];
670- launch_config.numAttrs = numPropsInLaunchPropList ;
699+ launch_config.numAttrs = launch_attribute. size () ;
671700
672701 UR_CHECK_ERROR (cuLaunchKernelEx (&launch_config, CuFunc,
673702 const_cast <void **>(ArgIndices.data ()),
0 commit comments