@@ -39,13 +39,42 @@ void KernelPropertiesTy::cacheGroupParams(const int32_t NumTeamsIn,
3939 GroupCounts = KEnv.GroupCounts ;
4040}
4141
42+ Error L0KernelTy::readKernelProperties (L0ProgramTy &Program) {
43+ const auto &l0Device = L0DeviceTy::makeL0Device (Program.getDevice ());
44+ auto &KernelPR = getProperties ();
45+ ze_kernel_properties_t KP = {};
46+ KP.stype = ZE_STRUCTURE_TYPE_KERNEL_PROPERTIES;
47+ KP.pNext = nullptr ;
48+ ze_kernel_preferred_group_size_properties_t KPrefGRPSize = {};
49+ KPrefGRPSize.stype = ZE_STRUCTURE_TYPE_KERNEL_PREFERRED_GROUP_SIZE_PROPERTIES;
50+ KPrefGRPSize.pNext = nullptr ;
51+ if (l0Device.getDriverAPIVersion () >= ZE_API_VERSION_1_2)
52+ KP.pNext = &KPrefGRPSize;
53+
54+ CALL_ZE_RET_ERROR (zeKernelGetProperties, zeKernel, &KP);
55+ KernelPR.SIMDWidth = KP.maxSubgroupSize ;
56+ KernelPR.Width = KP.maxSubgroupSize ;
57+
58+ if (KP.pNext )
59+ KernelPR.Width = KPrefGRPSize.preferredMultiple ;
60+
61+ if (!l0Device.isDeviceArch (DeviceArchTy::DeviceArch_Gen)) {
62+ KernelPR.Width = (std::max)(KernelPR.Width , 2 * KernelPR.SIMDWidth );
63+ }
64+ KernelPR.MaxThreadGroupSize = KP.maxSubgroupSize * KP.maxNumSubgroups ;
65+ return Plugin::success ();
66+ }
67+
4268Error L0KernelTy::buildKernel (L0ProgramTy &Program) {
4369 const auto *KernelName = getName ();
4470
4571 auto Module = Program.findModuleFromKernelName (KernelName);
4672 ze_kernel_desc_t KernelDesc = {ZE_STRUCTURE_TYPE_KERNEL_DESC, nullptr , 0 ,
4773 KernelName};
4874 CALL_ZE_RET_ERROR (zeKernelCreate, Module, &KernelDesc, &zeKernel);
75+ if (auto Err = readKernelProperties (Program))
76+ return Err;
77+
4978 return Plugin::success ();
5079}
5180
@@ -314,7 +343,23 @@ static Error launchKernelWithCmdQueue(L0DeviceTy &l0Device,
314343}
315344
316345Error L0KernelTy::setKernelGroups (L0DeviceTy &l0Device, L0LaunchEnvTy &KEnv,
317- int32_t NumTeams, int32_t ThreadLimit) const {
346+ uint32_t NumThreads[3 ], uint32_t NumBlocks[3 ]) const {
347+
348+ if (KernelEnvironment.Configuration .ExecMode != OMP_TGT_EXEC_MODE_BARE) {
349+ // For non-bare mode, the groups are already set in the launch
350+ KEnv.GroupCounts = {NumBlocks[0 ], NumBlocks[1 ], NumBlocks[2 ]};
351+ CALL_ZE_RET_ERROR (zeKernelSetGroupSize, getZeKernel (), NumThreads[0 ],
352+ NumThreads[1 ], NumThreads[2 ]);
353+ return Plugin::success ();
354+ }
355+
356+ int32_t NumTeams = NumThreads[0 ];
357+ int32_t ThreadLimit = NumBlocks[0 ];
358+ if (NumTeams < 0 )
359+ NumTeams = 0 ;
360+ if (ThreadLimit < 0 )
361+ ThreadLimit = 0 ;
362+
318363 uint32_t GroupSizes[3 ];
319364 auto DeviceId = l0Device.getDeviceId ();
320365 auto &KernelPR = KEnv.KernelPR ;
@@ -374,12 +419,6 @@ Error L0KernelTy::launchImpl(GenericDeviceTy &GenericDevice,
374419 auto zeKernel = getZeKernel ();
375420 auto DeviceId = l0Device.getDeviceId ();
376421 int32_t NumArgs = KernelArgs.NumArgs ;
377- int32_t NumTeams = NumThreads[0 ];
378- int32_t ThreadLimit = NumBlocks[0 ];
379- if (NumTeams < 0 )
380- NumTeams = 0 ;
381- if (ThreadLimit < 0 )
382- ThreadLimit = 0 ;
383422 INFO (OMP_INFOTYPE_PLUGIN_KERNEL, DeviceId, " Launching kernel " DPxMOD " ...\n " ,
384423 DPxPTR (zeKernel));
385424
@@ -404,7 +443,7 @@ Error L0KernelTy::launchImpl(GenericDeviceTy &GenericDevice,
404443 // Protect from kernel preparation to submission as kernels are shared.
405444 KernelPR.Mtx .lock ();
406445
407- if (auto Err = setKernelGroups (l0Device, KEnv, NumTeams, ThreadLimit ))
446+ if (auto Err = setKernelGroups (l0Device, KEnv, NumThreads, NumBlocks ))
408447 return Err;
409448
410449 // Set kernel arguments
0 commit comments