@@ -422,11 +422,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueEventsWait(
422422 phEventWaitList, phEvent);
423423}
424424
425- UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch (
426- ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
427- const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
428- const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
429- const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
425+ static ur_result_t
426+ enqueueKernelLaunch (ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel,
427+ uint32_t workDim, const size_t *pGlobalWorkOffset,
428+ const size_t *pGlobalWorkSize, const size_t *pLocalWorkSize,
429+ uint32_t numEventsInWaitList,
430+ const ur_event_handle_t *phEventWaitList,
431+ ur_event_handle_t *phEvent, size_t WorkGroupMemory) {
430432 // Preconditions
431433 UR_ASSERT (hQueue->getDevice () == hKernel->getProgram ()->getDevice (),
432434 UR_RESULT_ERROR_INVALID_KERNEL);
@@ -444,6 +446,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
444446 size_t ThreadsPerBlock[3 ] = {32u , 1u , 1u };
445447 size_t BlocksPerGrid[3 ] = {1u , 1u , 1u };
446448
449+ // Set work group memory so we can compute the whole memory requirement
450+ if (WorkGroupMemory)
451+ hKernel->setWorkGroupMemory (WorkGroupMemory);
447452 uint32_t LocalSize = hKernel->getLocalSize ();
448453 CUfunction CuFunc = hKernel->get ();
449454
@@ -503,6 +508,17 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
503508 return UR_RESULT_SUCCESS;
504509}
505510
511+ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch (
512+ ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
513+ const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
514+ const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
515+ const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
516+ return enqueueKernelLaunch (hQueue, hKernel, workDim, pGlobalWorkOffset,
517+ pGlobalWorkSize, pLocalWorkSize,
518+ numEventsInWaitList, phEventWaitList, phEvent,
519+ /* WorkGroupMemory=*/ 0 );
520+ }
521+
506522UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp (
507523 ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
508524 const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
@@ -513,8 +529,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
513529 coop_prop.id = UR_EXP_LAUNCH_PROPERTY_ID_COOPERATIVE;
514530 coop_prop.value .cooperative = 1 ;
515531 return urEnqueueKernelLaunchCustomExp (
516- hQueue, hKernel, workDim, pGlobalWorkSize, pLocalWorkSize, 1 ,
517- &coop_prop, numEventsInWaitList, phEventWaitList, phEvent);
532+ hQueue, hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize,
533+ pLocalWorkSize, 1 , &coop_prop, numEventsInWaitList, phEventWaitList,
534+ phEvent);
518535 }
519536 return urEnqueueKernelLaunch (hQueue, hKernel, workDim, pGlobalWorkOffset,
520537 pGlobalWorkSize, pLocalWorkSize,
@@ -523,16 +540,29 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
523540
524541UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp (
525542 ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
526- const size_t *pGlobalWorkSize , const size_t *pLocalWorkSize ,
527- uint32_t numPropsInLaunchPropList,
543+ const size_t *pGlobalWorkOffset , const size_t *pGlobalWorkSize ,
544+ const size_t *pLocalWorkSize, uint32_t numPropsInLaunchPropList,
528545 const ur_exp_launch_property_t *launchPropList,
529546 uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
530547 ur_event_handle_t *phEvent) {
531548
532- if (numPropsInLaunchPropList == 0 ) {
533- urEnqueueKernelLaunch (hQueue, hKernel, workDim, nullptr , pGlobalWorkSize,
534- pLocalWorkSize, numEventsInWaitList, phEventWaitList,
535- phEvent);
549+ size_t WorkGroupMemory = [&]() -> size_t {
550+ const ur_exp_launch_property_t *WorkGroupMemoryProp = std::find_if (
551+ launchPropList, launchPropList + numPropsInLaunchPropList,
552+ [](const ur_exp_launch_property_t &Prop) {
553+ return Prop.id == UR_EXP_LAUNCH_PROPERTY_ID_WORK_GROUP_MEMORY;
554+ });
555+ if (WorkGroupMemoryProp != launchPropList + numPropsInLaunchPropList)
556+ return WorkGroupMemoryProp->value .workgroup_mem_size ;
557+ return 0 ;
558+ }();
559+
560+ if (numPropsInLaunchPropList == 0 ||
561+ (WorkGroupMemory && numPropsInLaunchPropList == 1 )) {
562+ return enqueueKernelLaunch (hQueue, hKernel, workDim, pGlobalWorkOffset,
563+ pGlobalWorkSize, pLocalWorkSize,
564+ numEventsInWaitList, phEventWaitList, phEvent,
565+ WorkGroupMemory);
536566 }
537567#if CUDA_VERSION >= 11080
538568 // Preconditions
@@ -545,7 +575,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
545575 return UR_RESULT_ERROR_INVALID_NULL_POINTER;
546576 }
547577
548- std::vector<CUlaunchAttribute> launch_attribute (numPropsInLaunchPropList);
578+ std::vector<CUlaunchAttribute> launch_attribute;
579+ launch_attribute.reserve (numPropsInLaunchPropList);
549580
550581 // Early exit for zero size kernel
551582 if (*pGlobalWorkSize == 0 ) {
@@ -558,40 +589,35 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
558589 size_t ThreadsPerBlock[3 ] = {32u , 1u , 1u };
559590 size_t BlocksPerGrid[3 ] = {1u , 1u , 1u };
560591
592+ // Set work group memory so we can compute the whole memory requirement
593+ if (WorkGroupMemory)
594+ hKernel->setWorkGroupMemory (WorkGroupMemory);
561595 uint32_t LocalSize = hKernel->getLocalSize ();
562596 CUfunction CuFunc = hKernel->get ();
563597
564598 for (uint32_t i = 0 ; i < numPropsInLaunchPropList; i++) {
565599 switch (launchPropList[i].id ) {
566600 case UR_EXP_LAUNCH_PROPERTY_ID_IGNORE: {
567- launch_attribute[i].id = CU_LAUNCH_ATTRIBUTE_IGNORE;
601+ auto &attr = launch_attribute.emplace_back ();
602+ attr.id = CU_LAUNCH_ATTRIBUTE_IGNORE;
568603 break ;
569604 }
570605 case UR_EXP_LAUNCH_PROPERTY_ID_CLUSTER_DIMENSION: {
571-
572- launch_attribute[i] .id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
606+ auto &attr = launch_attribute. emplace_back ();
607+ attr .id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
573608 // Note that cuda orders from right to left wrt SYCL dimensional order.
574609 if (workDim == 3 ) {
575- launch_attribute[i].value .clusterDim .x =
576- launchPropList[i].value .clusterDim [2 ];
577- launch_attribute[i].value .clusterDim .y =
578- launchPropList[i].value .clusterDim [1 ];
579- launch_attribute[i].value .clusterDim .z =
580- launchPropList[i].value .clusterDim [0 ];
610+ attr.value .clusterDim .x = launchPropList[i].value .clusterDim [2 ];
611+ attr.value .clusterDim .y = launchPropList[i].value .clusterDim [1 ];
612+ attr.value .clusterDim .z = launchPropList[i].value .clusterDim [0 ];
581613 } else if (workDim == 2 ) {
582- launch_attribute[i].value .clusterDim .x =
583- launchPropList[i].value .clusterDim [1 ];
584- launch_attribute[i].value .clusterDim .y =
585- launchPropList[i].value .clusterDim [0 ];
586- launch_attribute[i].value .clusterDim .z =
587- launchPropList[i].value .clusterDim [2 ];
614+ attr.value .clusterDim .x = launchPropList[i].value .clusterDim [1 ];
615+ attr.value .clusterDim .y = launchPropList[i].value .clusterDim [0 ];
616+ attr.value .clusterDim .z = launchPropList[i].value .clusterDim [2 ];
588617 } else {
589- launch_attribute[i].value .clusterDim .x =
590- launchPropList[i].value .clusterDim [0 ];
591- launch_attribute[i].value .clusterDim .y =
592- launchPropList[i].value .clusterDim [1 ];
593- launch_attribute[i].value .clusterDim .z =
594- launchPropList[i].value .clusterDim [2 ];
618+ attr.value .clusterDim .x = launchPropList[i].value .clusterDim [0 ];
619+ attr.value .clusterDim .y = launchPropList[i].value .clusterDim [1 ];
620+ attr.value .clusterDim .z = launchPropList[i].value .clusterDim [2 ];
595621 }
596622
597623 UR_CHECK_ERROR (cuFuncSetAttribute (
@@ -600,9 +626,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
600626 break ;
601627 }
602628 case UR_EXP_LAUNCH_PROPERTY_ID_COOPERATIVE: {
603- launch_attribute[i].id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE;
604- launch_attribute[i].value .cooperative =
605- launchPropList[i].value .cooperative ;
629+ auto &attr = launch_attribute.emplace_back ();
630+ attr.id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE;
631+ attr.value .cooperative = launchPropList[i].value .cooperative ;
632+ break ;
633+ }
634+ case UR_EXP_LAUNCH_PROPERTY_ID_WORK_GROUP_MEMORY: {
606635 break ;
607636 }
608637 default : {
@@ -615,8 +644,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
615644 // using the standard UR_CHECK_ERROR
616645 if (ur_result_t Ret =
617646 setKernelParams (hQueue->getContext (), hQueue->Device , workDim,
618- nullptr , pGlobalWorkSize, pLocalWorkSize, hKernel ,
619- CuFunc, ThreadsPerBlock, BlocksPerGrid);
647+ pGlobalWorkOffset , pGlobalWorkSize, pLocalWorkSize,
648+ hKernel, CuFunc, ThreadsPerBlock, BlocksPerGrid);
620649 Ret != UR_RESULT_SUCCESS)
621650 return Ret;
622651
@@ -664,7 +693,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
664693 launch_config.sharedMemBytes = LocalSize;
665694 launch_config.hStream = CuStream;
666695 launch_config.attrs = &launch_attribute[0 ];
667- launch_config.numAttrs = numPropsInLaunchPropList ;
696+ launch_config.numAttrs = launch_attribute. size () ;
668697
669698 UR_CHECK_ERROR (cuLaunchKernelEx (&launch_config, CuFunc,
670699 const_cast <void **>(ArgIndices.data ()),
0 commit comments