@@ -1143,23 +1143,26 @@ void EncodeDispatchKernel<GfxFamily>::encodeThreadGroupDispatch(InterfaceDescrip
11431143 threadsPerXeCore /= 2 ;
11441144 }
11451145 auto tgDispatchSizeSelected = 8 ;
1146+ uint32_t numberOfThreadsInThreadGroup = interfaceDescriptor.getNumberOfThreadsInGpgpuThreadGroup ();
11461147
1147- if (threadGroupDimensions[ 0 ] > 1 && (threadGroupDimensions[ 1 ] > 1 || threadGroupDimensions[ 2 ] > 1 )) {
1148- while (threadGroupDimensions[ 0 ] % tgDispatchSizeSelected != 0 ) {
1148+ if (walkerCmd. getThreadGroupIdXDimension () > 1 && (walkerCmd. getThreadGroupIdYDimension () > 1 || walkerCmd. getThreadGroupIdZDimension () > 1 )) {
1149+ while (walkerCmd. getThreadGroupIdXDimension () % tgDispatchSizeSelected != 0 ) {
11491150 tgDispatchSizeSelected /= 2 ;
11501151 }
1151- } else if (threadGroupDimensions[ 1 ] > 1 && threadGroupDimensions[ 2 ] > 1 ) {
1152- while (threadGroupDimensions[ 1 ] % tgDispatchSizeSelected != 0 ) {
1152+ } else if (walkerCmd. getThreadGroupIdYDimension () > 1 && walkerCmd. getThreadGroupIdZDimension () > 1 ) {
1153+ while (walkerCmd. getThreadGroupIdYDimension () % tgDispatchSizeSelected != 0 ) {
11531154 tgDispatchSizeSelected /= 2 ;
11541155 }
11551156 }
11561157
1158+ auto workgroupCount = walkerCmd.getThreadGroupIdXDimension () * walkerCmd.getThreadGroupIdYDimension () * walkerCmd.getThreadGroupIdZDimension ();
1159+
11571160 // make sure we fit all xe core
1158- while (threadGroupCount / tgDispatchSizeSelected < hwInfo.gtSystemInfo .MaxSubSlicesSupported * tileCount && tgDispatchSizeSelected > 1 ) {
1161+ while (workgroupCount / tgDispatchSizeSelected < hwInfo.gtSystemInfo .MaxSubSlicesSupported * tileCount && tgDispatchSizeSelected > 1 ) {
11591162 tgDispatchSizeSelected /= 2 ;
11601163 }
11611164
1162- auto threadCountPerGrouping = tgDispatchSizeSelected * threadsPerThreadGroup ;
1165+ auto threadCountPerGrouping = tgDispatchSizeSelected * numberOfThreadsInThreadGroup ;
11631166 // make sure we do not use more threads then present on each xe core
11641167 while (threadCountPerGrouping > threadsPerXeCore && tgDispatchSizeSelected > 1 ) {
11651168 tgDispatchSizeSelected /= 2 ;
@@ -1184,25 +1187,26 @@ void EncodeDispatchKernel<GfxFamily>::encodeThreadGroupDispatch(InterfaceDescrip
11841187 uint32_t availableThreadCount = gfxCoreHelper.calculateAvailableThreadCount (hwInfo, grfCount);
11851188 availableThreadCount *= tileCount;
11861189
1187- uint32_t dispatchedTotalThreadCount = threadsPerThreadGroup * threadGroupCount;
1188- UNRECOVERABLE_IF (threadsPerThreadGroup == 0u );
1190+ uint32_t numberOfThreadsInThreadGroup = interfaceDescriptor.getNumberOfThreadsInGpgpuThreadGroup ();
1191+ uint32_t dispatchedTotalThreadCount = numberOfThreadsInThreadGroup * threadGroupCount;
1192+ UNRECOVERABLE_IF (numberOfThreadsInThreadGroup == 0u );
11891193 auto tgDispatchSizeSelected = 1u ;
11901194
11911195 if (dispatchedTotalThreadCount <= availableThreadCount) {
11921196 tgDispatchSizeSelected = 1 ;
1193- } else if (threadsPerThreadGroup <= maxThreadsInTGForTGDispatchSize8) {
1197+ } else if (numberOfThreadsInThreadGroup <= maxThreadsInTGForTGDispatchSize8) {
11941198 tgDispatchSizeSelected = 8 ;
1195- } else if (threadsPerThreadGroup <= maxThreadsInTGForTGDispatchSize4) {
1199+ } else if (numberOfThreadsInThreadGroup <= maxThreadsInTGForTGDispatchSize4) {
11961200 tgDispatchSizeSelected = 4 ;
11971201 } else {
11981202 tgDispatchSizeSelected = 2 ;
11991203 }
1200- if (threadGroupDimensions[ 0 ] > 1 && (threadGroupDimensions[ 1 ] > 1 || threadGroupDimensions[ 2 ] > 1 )) {
1201- while (threadGroupDimensions[ 0 ] % tgDispatchSizeSelected != 0 ) {
1204+ if (walkerCmd. getThreadGroupIdXDimension () > 1 && (walkerCmd. getThreadGroupIdYDimension () > 1 || walkerCmd. getThreadGroupIdZDimension () > 1 )) {
1205+ while (walkerCmd. getThreadGroupIdXDimension () % tgDispatchSizeSelected != 0 ) {
12021206 tgDispatchSizeSelected /= 2 ;
12031207 }
1204- } else if (threadGroupDimensions[ 1 ] > 1 && threadGroupDimensions[ 2 ] > 1 ) {
1205- while (threadGroupDimensions[ 1 ] % tgDispatchSizeSelected != 0 ) {
1208+ } else if (walkerCmd. getThreadGroupIdYDimension () > 1 && walkerCmd. getThreadGroupIdZDimension () > 1 ) {
1209+ while (walkerCmd. getThreadGroupIdYDimension () % tgDispatchSizeSelected != 0 ) {
12061210 tgDispatchSizeSelected /= 2 ;
12071211 }
12081212 }
0 commit comments