@@ -971,15 +971,57 @@ ur_result_t ur_queue_immediate_in_order_t::enqueueCooperativeKernelLaunchExp(
971971 const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
972972 const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
973973 const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
974- std::ignore = hKernel;
975- std::ignore = workDim;
976- std::ignore = pGlobalWorkOffset;
977- std::ignore = pGlobalWorkSize;
978- std::ignore = pLocalWorkSize;
979- std::ignore = numEventsInWaitList;
980- std::ignore = phEventWaitList;
981- std::ignore = phEvent;
982- return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
974+ TRACK_SCOPE_LATENCY (
975+ " ur_queue_immediate_in_order_t::enqueueCooperativeKernelLaunchExp" );
976+
977+ UR_ASSERT (hKernel, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
978+ UR_ASSERT (hKernel->getProgramHandle (), UR_RESULT_ERROR_INVALID_NULL_POINTER);
979+
980+ UR_ASSERT (workDim > 0 , UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
981+ UR_ASSERT (workDim < 4 , UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
982+
983+ ze_kernel_handle_t hZeKernel = hKernel->getZeHandle (hDevice);
984+
985+ std::scoped_lock<ur_shared_mutex, ur_shared_mutex> Lock (this ->Mutex ,
986+ hKernel->Mutex );
987+
988+ ze_group_count_t zeThreadGroupDimensions{1 , 1 , 1 };
989+ uint32_t WG[3 ]{};
990+ UR_CALL (calculateKernelWorkDimensions (hZeKernel, hDevice,
991+ zeThreadGroupDimensions, WG, workDim,
992+ pGlobalWorkSize, pLocalWorkSize));
993+
994+ auto signalEvent = getSignalEvent (phEvent, UR_COMMAND_KERNEL_LAUNCH);
995+
996+ auto waitList = getWaitListView (phEventWaitList, numEventsInWaitList);
997+
998+ bool memoryMigrated = false ;
999+ auto memoryMigrate = [&](void *src, void *dst, size_t size) {
1000+ ZE2UR_CALL_THROWS (zeCommandListAppendMemoryCopy,
1001+ (handler.commandList .get (), dst, src, size, nullptr ,
1002+ waitList.second , waitList.first ));
1003+ memoryMigrated = true ;
1004+ };
1005+
1006+ UR_CALL (hKernel->prepareForSubmission (hContext, hDevice, pGlobalWorkOffset,
1007+ workDim, WG[0 ], WG[1 ], WG[2 ],
1008+ memoryMigrate));
1009+
1010+ if (memoryMigrated) {
1011+ // If memory was migrated, the wait list was already passed to the
1012+ // copy command, so we don't need to pass it to the kernel launch again.
1013+ waitList.first = nullptr ;
1014+ waitList.second = 0 ;
1015+ }
1016+
1017+ TRACK_SCOPE_LATENCY (" ur_queue_immediate_in_order_t::"
1018+ " zeCommandListAppendLaunchCooperativeKernel" );
1019+ auto zeSignalEvent = signalEvent ? signalEvent->getZeEvent () : nullptr ;
1020+ ZE2UR_CALL (zeCommandListAppendLaunchCooperativeKernel,
1021+ (handler.commandList .get (), hZeKernel, &zeThreadGroupDimensions,
1022+ zeSignalEvent, waitList.second , waitList.first ));
1023+
1024+ return UR_RESULT_SUCCESS;
9831025}
9841026
9851027ur_result_t ur_queue_immediate_in_order_t::enqueueTimestampRecordingExp (
0 commit comments