@@ -484,6 +484,133 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
484
484
numEventsInWaitList, phEventWaitList, phEvent);
485
485
}
486
486
487
+ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp (
488
+ ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
489
+ const size_t *pGlobalWorkSize, const size_t *pLocalWorkSize,
490
+ uint32_t numAttrsInLaunchAttrList,
491
+ const ur_exp_launch_attribute_t *launchAttrList,
492
+ uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
493
+ ur_event_handle_t *phEvent) {
494
+
495
+ if (numAttrsInLaunchAttrList == 0 ) {
496
+ urEnqueueKernelLaunch (hQueue, hKernel, workDim, nullptr , pGlobalWorkSize,
497
+ pLocalWorkSize, numEventsInWaitList, phEventWaitList,
498
+ phEvent);
499
+ }
500
+
501
+ // Preconditions
502
+ UR_ASSERT (hQueue->getContext () == hKernel->getContext (),
503
+ UR_RESULT_ERROR_INVALID_KERNEL);
504
+ UR_ASSERT (workDim > 0 , UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
505
+ UR_ASSERT (workDim < 4 , UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
506
+
507
+ if (launchAttrList == NULL ) {
508
+ return UR_RESULT_ERROR_INVALID_NULL_POINTER;
509
+ }
510
+
511
+ std::vector<CUlaunchAttribute> launch_attribute (numAttrsInLaunchAttrList);
512
+ for (uint32_t i = 0 ; i < numAttrsInLaunchAttrList; i++) {
513
+ switch (launchAttrList[i].id ) {
514
+ case UR_EXP_LAUNCH_ATTRIBUTE_ID_IGNORE: {
515
+ launch_attribute[i].id = CU_LAUNCH_ATTRIBUTE_IGNORE;
516
+ break ;
517
+ }
518
+ case UR_EXP_LAUNCH_ATTRIBUTE_ID_CLUSTER_DIMENSION: {
519
+
520
+ launch_attribute[i].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
521
+ launch_attribute[i].value .clusterDim .x =
522
+ launchAttrList[i].value .clusterDim [0 ];
523
+ launch_attribute[i].value .clusterDim .y =
524
+ launchAttrList[i].value .clusterDim [1 ];
525
+ launch_attribute[i].value .clusterDim .z =
526
+ launchAttrList[i].value .clusterDim [2 ];
527
+ break ;
528
+ }
529
+ case UR_EXP_LAUNCH_ATTRIBUTE_ID_COOPERATIVE: {
530
+ launch_attribute[i].id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE;
531
+ launch_attribute[i].value .cooperative =
532
+ launchAttrList[i].value .cooperative ;
533
+ break ;
534
+ }
535
+ default : {
536
+ return UR_RESULT_ERROR_INVALID_ENUMERATION;
537
+ }
538
+ }
539
+ }
540
+
541
+ if (*pGlobalWorkSize == 0 ) {
542
+ return urEnqueueEventsWaitWithBarrier (hQueue, numEventsInWaitList,
543
+ phEventWaitList, phEvent);
544
+ }
545
+
546
+ // Set the number of threads per block to the number of threads per warp
547
+ // by default unless user has provided a better number
548
+ size_t ThreadsPerBlock[3 ] = {32u , 1u , 1u };
549
+ size_t BlocksPerGrid[3 ] = {1u , 1u , 1u };
550
+
551
+ uint32_t LocalSize = hKernel->getLocalSize ();
552
+ ur_result_t Result = UR_RESULT_SUCCESS;
553
+ CUfunction CuFunc = hKernel->get ();
554
+
555
+ Result = setKernelParams (hQueue->getContext (), hQueue->Device , workDim,
556
+ nullptr , pGlobalWorkSize, pLocalWorkSize, hKernel,
557
+ CuFunc, ThreadsPerBlock, BlocksPerGrid);
558
+ if (Result != UR_RESULT_SUCCESS) {
559
+ return Result;
560
+ }
561
+
562
+ try {
563
+ std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr };
564
+
565
+ uint32_t StreamToken;
566
+ ur_stream_guard_ Guard;
567
+ CUstream CuStream = hQueue->getNextComputeStream (
568
+ numEventsInWaitList, phEventWaitList, Guard, &StreamToken);
569
+
570
+ Result = enqueueEventsWait (hQueue, CuStream, numEventsInWaitList,
571
+ phEventWaitList);
572
+
573
+ if (phEvent) {
574
+ RetImplEvent =
575
+ std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative (
576
+ UR_COMMAND_KERNEL_LAUNCH, hQueue, CuStream, StreamToken));
577
+ UR_CHECK_ERROR (RetImplEvent->start ());
578
+ }
579
+
580
+ auto &ArgIndices = hKernel->getArgIndices ();
581
+
582
+ CUlaunchConfig launch_config;
583
+ launch_config.gridDimX = BlocksPerGrid[0 ];
584
+ launch_config.gridDimY = BlocksPerGrid[1 ];
585
+ launch_config.gridDimZ = BlocksPerGrid[2 ];
586
+ launch_config.blockDimX = ThreadsPerBlock[0 ];
587
+ launch_config.blockDimY = ThreadsPerBlock[1 ];
588
+ launch_config.blockDimZ = ThreadsPerBlock[2 ];
589
+
590
+ launch_config.sharedMemBytes = LocalSize;
591
+ launch_config.hStream = CuStream;
592
+ launch_config.attrs = &launch_attribute[0 ];
593
+ launch_config.numAttrs = numAttrsInLaunchAttrList;
594
+
595
+ UR_CHECK_ERROR (cuLaunchKernelEx (&launch_config, CuFunc,
596
+ const_cast <void **>(ArgIndices.data ()),
597
+ nullptr ));
598
+
599
+ if (LocalSize != 0 )
600
+ hKernel->clearLocalSize ();
601
+
602
+ if (phEvent) {
603
+ UR_CHECK_ERROR (RetImplEvent->record ());
604
+ *phEvent = RetImplEvent.release ();
605
+ }
606
+
607
+ } catch (ur_result_t Err) {
608
+ Result = Err;
609
+ }
610
+ return Result;
611
+ }
612
+
613
+
487
614
// / Set parameters for general 3D memory copy.
488
615
// / If the source and/or destination is on the device, SrcPtr and/or DstPtr
489
616
// / must be a pointer to a CUdeviceptr
0 commit comments