@@ -422,13 +422,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueEventsWait(
422422 phEventWaitList, phEvent);
423423}
424424
425- static ur_result_t
426- enqueueKernelLaunch (ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel,
427- uint32_t workDim, const size_t *pGlobalWorkOffset,
428- const size_t *pGlobalWorkSize, const size_t *pLocalWorkSize,
429- uint32_t numEventsInWaitList,
430- const ur_event_handle_t *phEventWaitList,
431- ur_event_handle_t *phEvent, size_t WorkGroupMemory) {
425+ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch (
426+ ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
427+ const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
428+ const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
429+ const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
432430 // Preconditions
433431 UR_ASSERT (hQueue->getDevice () == hKernel->getProgram ()->getDevice (),
434432 UR_RESULT_ERROR_INVALID_KERNEL);
@@ -446,9 +444,6 @@ enqueueKernelLaunch(ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel,
446444 size_t ThreadsPerBlock[3 ] = {32u , 1u , 1u };
447445 size_t BlocksPerGrid[3 ] = {1u , 1u , 1u };
448446
449- // Set work group memory so we can compute the whole memory requirement
450- if (WorkGroupMemory)
451- hKernel->setWorkGroupMemory (WorkGroupMemory);
452447 uint32_t LocalSize = hKernel->getLocalSize ();
453448 CUfunction CuFunc = hKernel->get ();
454449
@@ -511,17 +506,6 @@ enqueueKernelLaunch(ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel,
511506 return UR_RESULT_SUCCESS;
512507}
513508
514- UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch (
515- ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
516- const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
517- const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
518- const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
519- return enqueueKernelLaunch (hQueue, hKernel, workDim, pGlobalWorkOffset,
520- pGlobalWorkSize, pLocalWorkSize,
521- numEventsInWaitList, phEventWaitList, phEvent,
522- /* WorkGroupMemory=*/ 0 );
523- }
524-
525509UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp (
526510 ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
527511 const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
@@ -532,9 +516,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
532516 coop_prop.id = UR_EXP_LAUNCH_PROPERTY_ID_COOPERATIVE;
533517 coop_prop.value .cooperative = 1 ;
534518 return urEnqueueKernelLaunchCustomExp (
535- hQueue, hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize,
536- pLocalWorkSize, 1 , &coop_prop, numEventsInWaitList, phEventWaitList,
537- phEvent);
519+ hQueue, hKernel, workDim, pGlobalWorkSize, pLocalWorkSize, 1 ,
520+ &coop_prop, numEventsInWaitList, phEventWaitList, phEvent);
538521 }
539522 return urEnqueueKernelLaunch (hQueue, hKernel, workDim, pGlobalWorkOffset,
540523 pGlobalWorkSize, pLocalWorkSize,
@@ -543,29 +526,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
543526
544527UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp (
545528 ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
546- const size_t *pGlobalWorkOffset , const size_t *pGlobalWorkSize ,
547- const size_t *pLocalWorkSize, uint32_t numPropsInLaunchPropList,
529+ const size_t *pGlobalWorkSize , const size_t *pLocalWorkSize ,
530+ uint32_t numPropsInLaunchPropList,
548531 const ur_exp_launch_property_t *launchPropList,
549532 uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
550533 ur_event_handle_t *phEvent) {
551534
552- size_t WorkGroupMemory = [&]() -> size_t {
553- const ur_exp_launch_property_t *WorkGroupMemoryProp = std::find_if (
554- launchPropList, launchPropList + numPropsInLaunchPropList,
555- [](const ur_exp_launch_property_t &Prop) {
556- return Prop.id == UR_EXP_LAUNCH_PROPERTY_ID_WORK_GROUP_MEMORY;
557- });
558- if (WorkGroupMemoryProp != launchPropList + numPropsInLaunchPropList)
559- return WorkGroupMemoryProp->value .workgroup_mem_size ;
560- return 0 ;
561- }();
562-
563- if (numPropsInLaunchPropList == 0 ||
564- (WorkGroupMemory && numPropsInLaunchPropList == 1 )) {
565- return enqueueKernelLaunch (hQueue, hKernel, workDim, pGlobalWorkOffset,
566- pGlobalWorkSize, pLocalWorkSize,
567- numEventsInWaitList, phEventWaitList, phEvent,
568- WorkGroupMemory);
535+ if (numPropsInLaunchPropList == 0 ) {
536+ urEnqueueKernelLaunch (hQueue, hKernel, workDim, nullptr , pGlobalWorkSize,
537+ pLocalWorkSize, numEventsInWaitList, phEventWaitList,
538+ phEvent);
569539 }
570540#if CUDA_VERSION >= 11080
571541 // Preconditions
@@ -578,8 +548,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
578548 return UR_RESULT_ERROR_INVALID_NULL_POINTER;
579549 }
580550
581- std::vector<CUlaunchAttribute> launch_attribute;
582- launch_attribute.reserve (numPropsInLaunchPropList);
551+ std::vector<CUlaunchAttribute> launch_attribute (numPropsInLaunchPropList);
583552
584553 // Early exit for zero size kernel
585554 if (*pGlobalWorkSize == 0 ) {
@@ -592,35 +561,40 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
592561 size_t ThreadsPerBlock[3 ] = {32u , 1u , 1u };
593562 size_t BlocksPerGrid[3 ] = {1u , 1u , 1u };
594563
595- // Set work group memory so we can compute the whole memory requirement
596- if (WorkGroupMemory)
597- hKernel->setWorkGroupMemory (WorkGroupMemory);
598564 uint32_t LocalSize = hKernel->getLocalSize ();
599565 CUfunction CuFunc = hKernel->get ();
600566
601567 for (uint32_t i = 0 ; i < numPropsInLaunchPropList; i++) {
602568 switch (launchPropList[i].id ) {
603569 case UR_EXP_LAUNCH_PROPERTY_ID_IGNORE: {
604- auto &attr = launch_attribute.emplace_back ();
605- attr.id = CU_LAUNCH_ATTRIBUTE_IGNORE;
570+ launch_attribute[i].id = CU_LAUNCH_ATTRIBUTE_IGNORE;
606571 break ;
607572 }
608573 case UR_EXP_LAUNCH_PROPERTY_ID_CLUSTER_DIMENSION: {
609- auto &attr = launch_attribute. emplace_back ();
610- attr .id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
574+
575+ launch_attribute[i] .id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
611576 // Note that cuda orders from right to left wrt SYCL dimensional order.
612577 if (workDim == 3 ) {
613- attr.value .clusterDim .x = launchPropList[i].value .clusterDim [2 ];
614- attr.value .clusterDim .y = launchPropList[i].value .clusterDim [1 ];
615- attr.value .clusterDim .z = launchPropList[i].value .clusterDim [0 ];
578+ launch_attribute[i].value .clusterDim .x =
579+ launchPropList[i].value .clusterDim [2 ];
580+ launch_attribute[i].value .clusterDim .y =
581+ launchPropList[i].value .clusterDim [1 ];
582+ launch_attribute[i].value .clusterDim .z =
583+ launchPropList[i].value .clusterDim [0 ];
616584 } else if (workDim == 2 ) {
617- attr.value .clusterDim .x = launchPropList[i].value .clusterDim [1 ];
618- attr.value .clusterDim .y = launchPropList[i].value .clusterDim [0 ];
619- attr.value .clusterDim .z = launchPropList[i].value .clusterDim [2 ];
585+ launch_attribute[i].value .clusterDim .x =
586+ launchPropList[i].value .clusterDim [1 ];
587+ launch_attribute[i].value .clusterDim .y =
588+ launchPropList[i].value .clusterDim [0 ];
589+ launch_attribute[i].value .clusterDim .z =
590+ launchPropList[i].value .clusterDim [2 ];
620591 } else {
621- attr.value .clusterDim .x = launchPropList[i].value .clusterDim [0 ];
622- attr.value .clusterDim .y = launchPropList[i].value .clusterDim [1 ];
623- attr.value .clusterDim .z = launchPropList[i].value .clusterDim [2 ];
592+ launch_attribute[i].value .clusterDim .x =
593+ launchPropList[i].value .clusterDim [0 ];
594+ launch_attribute[i].value .clusterDim .y =
595+ launchPropList[i].value .clusterDim [1 ];
596+ launch_attribute[i].value .clusterDim .z =
597+ launchPropList[i].value .clusterDim [2 ];
624598 }
625599
626600 UR_CHECK_ERROR (cuFuncSetAttribute (
@@ -629,12 +603,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
629603 break ;
630604 }
631605 case UR_EXP_LAUNCH_PROPERTY_ID_COOPERATIVE: {
632- auto &attr = launch_attribute.emplace_back ();
633- attr.id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE;
634- attr.value .cooperative = launchPropList[i].value .cooperative ;
635- break ;
636- }
637- case UR_EXP_LAUNCH_PROPERTY_ID_WORK_GROUP_MEMORY: {
606+ launch_attribute[i].id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE;
607+ launch_attribute[i].value .cooperative =
608+ launchPropList[i].value .cooperative ;
638609 break ;
639610 }
640611 default : {
@@ -647,8 +618,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
647618 // using the standard UR_CHECK_ERROR
648619 if (ur_result_t Ret =
649620 setKernelParams (hQueue->getContext (), hQueue->Device , workDim,
650- pGlobalWorkOffset , pGlobalWorkSize, pLocalWorkSize,
651- hKernel, CuFunc, ThreadsPerBlock, BlocksPerGrid);
621+ nullptr , pGlobalWorkSize, pLocalWorkSize, hKernel ,
622+ CuFunc, ThreadsPerBlock, BlocksPerGrid);
652623 Ret != UR_RESULT_SUCCESS)
653624 return Ret;
654625
@@ -696,7 +667,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
696667 launch_config.sharedMemBytes = LocalSize;
697668 launch_config.hStream = CuStream;
698669 launch_config.attrs = &launch_attribute[0 ];
699- launch_config.numAttrs = launch_attribute. size () ;
670+ launch_config.numAttrs = numPropsInLaunchPropList ;
700671
701672 UR_CHECK_ERROR (cuLaunchKernelEx (&launch_config, CuFunc,
702673 const_cast <void **>(ArgIndices.data ()),
0 commit comments