@@ -978,15 +978,57 @@ ur_result_t ur_queue_immediate_in_order_t::enqueueCooperativeKernelLaunchExp(
978978 const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
979979 const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
980980 const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
981- std::ignore = hKernel;
982- std::ignore = workDim;
983- std::ignore = pGlobalWorkOffset;
984- std::ignore = pGlobalWorkSize;
985- std::ignore = pLocalWorkSize;
986- std::ignore = numEventsInWaitList;
987- std::ignore = phEventWaitList;
988- std::ignore = phEvent;
989- return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
981+ TRACK_SCOPE_LATENCY (
982+ " ur_queue_immediate_in_order_t::enqueueCooperativeKernelLaunchExp" );
983+
984+ UR_ASSERT (hKernel, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
985+ UR_ASSERT (hKernel->getProgramHandle (), UR_RESULT_ERROR_INVALID_NULL_POINTER);
986+
987+ UR_ASSERT (workDim > 0 , UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
988+ UR_ASSERT (workDim < 4 , UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
989+
990+ ze_kernel_handle_t hZeKernel = hKernel->getZeHandle (hDevice);
991+
992+ std::scoped_lock<ur_shared_mutex, ur_shared_mutex> Lock (this ->Mutex ,
993+ hKernel->Mutex );
994+
995+ ze_group_count_t zeThreadGroupDimensions{1 , 1 , 1 };
996+ uint32_t WG[3 ]{};
997+ UR_CALL (calculateKernelWorkDimensions (hZeKernel, hDevice,
998+ zeThreadGroupDimensions, WG, workDim,
999+ pGlobalWorkSize, pLocalWorkSize));
1000+
1001+ auto signalEvent = getSignalEvent (phEvent, UR_COMMAND_KERNEL_LAUNCH);
1002+
1003+ auto waitList = getWaitListView (phEventWaitList, numEventsInWaitList);
1004+
1005+ bool memoryMigrated = false ;
1006+ auto memoryMigrate = [&](void *src, void *dst, size_t size) {
1007+ ZE2UR_CALL_THROWS (zeCommandListAppendMemoryCopy,
1008+ (handler.commandList .get (), dst, src, size, nullptr ,
1009+ waitList.second , waitList.first ));
1010+ memoryMigrated = true ;
1011+ };
1012+
1013+ UR_CALL (hKernel->prepareForSubmission (hContext, hDevice, pGlobalWorkOffset,
1014+ workDim, WG[0 ], WG[1 ], WG[2 ],
1015+ memoryMigrate));
1016+
1017+ if (memoryMigrated) {
1018+ // If memory was migrated, we don't need to pass the wait list to
1019+ // the copy command again.
1020+ waitList.first = nullptr ;
1021+ waitList.second = 0 ;
1022+ }
1023+
1024+ TRACK_SCOPE_LATENCY (" ur_queue_immediate_in_order_t::"
1025+ " zeCommandListAppendLaunchCooperativeKernel" );
1026+ auto zeSignalEvent = signalEvent ? signalEvent->getZeEvent () : nullptr ;
1027+ ZE2UR_CALL (zeCommandListAppendLaunchCooperativeKernel,
1028+ (handler.commandList .get (), hZeKernel, &zeThreadGroupDimensions,
1029+ zeSignalEvent, waitList.second , waitList.first ));
1030+
1031+ return UR_RESULT_SUCCESS;
9901032}
9911033
9921034ur_result_t ur_queue_immediate_in_order_t::enqueueTimestampRecordingExp (
0 commit comments