Skip to content

Commit 022f9e6

Browse files
Revert "refactor: change encoder for thread group over dispatch 2/n"
This reverts commit 0466317. Signed-off-by: Compute-Runtime-Validation <[email protected]>
1 parent 7f81179 commit 022f9e6

File tree

2 files changed

+19
-15
lines changed

2 files changed

+19
-15
lines changed

shared/source/command_container/command_encoder_xehp_and_later.inl

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,23 +1143,26 @@ void EncodeDispatchKernel<GfxFamily>::encodeThreadGroupDispatch(InterfaceDescrip
11431143
threadsPerXeCore /= 2;
11441144
}
11451145
auto tgDispatchSizeSelected = 8;
1146+
uint32_t numberOfThreadsInThreadGroup = interfaceDescriptor.getNumberOfThreadsInGpgpuThreadGroup();
11461147

1147-
if (threadGroupDimensions[0] > 1 && (threadGroupDimensions[1] > 1 || threadGroupDimensions[2] > 1)) {
1148-
while (threadGroupDimensions[0] % tgDispatchSizeSelected != 0) {
1148+
if (walkerCmd.getThreadGroupIdXDimension() > 1 && (walkerCmd.getThreadGroupIdYDimension() > 1 || walkerCmd.getThreadGroupIdZDimension() > 1)) {
1149+
while (walkerCmd.getThreadGroupIdXDimension() % tgDispatchSizeSelected != 0) {
11491150
tgDispatchSizeSelected /= 2;
11501151
}
1151-
} else if (threadGroupDimensions[1] > 1 && threadGroupDimensions[2] > 1) {
1152-
while (threadGroupDimensions[1] % tgDispatchSizeSelected != 0) {
1152+
} else if (walkerCmd.getThreadGroupIdYDimension() > 1 && walkerCmd.getThreadGroupIdZDimension() > 1) {
1153+
while (walkerCmd.getThreadGroupIdYDimension() % tgDispatchSizeSelected != 0) {
11531154
tgDispatchSizeSelected /= 2;
11541155
}
11551156
}
11561157

1158+
auto workgroupCount = walkerCmd.getThreadGroupIdXDimension() * walkerCmd.getThreadGroupIdYDimension() * walkerCmd.getThreadGroupIdZDimension();
1159+
11571160
// make sure we fit all xe core
1158-
while (threadGroupCount / tgDispatchSizeSelected < hwInfo.gtSystemInfo.MaxSubSlicesSupported * tileCount && tgDispatchSizeSelected > 1) {
1161+
while (workgroupCount / tgDispatchSizeSelected < hwInfo.gtSystemInfo.MaxSubSlicesSupported * tileCount && tgDispatchSizeSelected > 1) {
11591162
tgDispatchSizeSelected /= 2;
11601163
}
11611164

1162-
auto threadCountPerGrouping = tgDispatchSizeSelected * threadsPerThreadGroup;
1165+
auto threadCountPerGrouping = tgDispatchSizeSelected * numberOfThreadsInThreadGroup;
11631166
// make sure we do not use more threads then present on each xe core
11641167
while (threadCountPerGrouping > threadsPerXeCore && tgDispatchSizeSelected > 1) {
11651168
tgDispatchSizeSelected /= 2;
@@ -1184,25 +1187,26 @@ void EncodeDispatchKernel<GfxFamily>::encodeThreadGroupDispatch(InterfaceDescrip
11841187
uint32_t availableThreadCount = gfxCoreHelper.calculateAvailableThreadCount(hwInfo, grfCount);
11851188
availableThreadCount *= tileCount;
11861189

1187-
uint32_t dispatchedTotalThreadCount = threadsPerThreadGroup * threadGroupCount;
1188-
UNRECOVERABLE_IF(threadsPerThreadGroup == 0u);
1190+
uint32_t numberOfThreadsInThreadGroup = interfaceDescriptor.getNumberOfThreadsInGpgpuThreadGroup();
1191+
uint32_t dispatchedTotalThreadCount = numberOfThreadsInThreadGroup * threadGroupCount;
1192+
UNRECOVERABLE_IF(numberOfThreadsInThreadGroup == 0u);
11891193
auto tgDispatchSizeSelected = 1u;
11901194

11911195
if (dispatchedTotalThreadCount <= availableThreadCount) {
11921196
tgDispatchSizeSelected = 1;
1193-
} else if (threadsPerThreadGroup <= maxThreadsInTGForTGDispatchSize8) {
1197+
} else if (numberOfThreadsInThreadGroup <= maxThreadsInTGForTGDispatchSize8) {
11941198
tgDispatchSizeSelected = 8;
1195-
} else if (threadsPerThreadGroup <= maxThreadsInTGForTGDispatchSize4) {
1199+
} else if (numberOfThreadsInThreadGroup <= maxThreadsInTGForTGDispatchSize4) {
11961200
tgDispatchSizeSelected = 4;
11971201
} else {
11981202
tgDispatchSizeSelected = 2;
11991203
}
1200-
if (threadGroupDimensions[0] > 1 && (threadGroupDimensions[1] > 1 || threadGroupDimensions[2] > 1)) {
1201-
while (threadGroupDimensions[0] % tgDispatchSizeSelected != 0) {
1204+
if (walkerCmd.getThreadGroupIdXDimension() > 1 && (walkerCmd.getThreadGroupIdYDimension() > 1 || walkerCmd.getThreadGroupIdZDimension() > 1)) {
1205+
while (walkerCmd.getThreadGroupIdXDimension() % tgDispatchSizeSelected != 0) {
12021206
tgDispatchSizeSelected /= 2;
12031207
}
1204-
} else if (threadGroupDimensions[1] > 1 && threadGroupDimensions[2] > 1) {
1205-
while (threadGroupDimensions[1] % tgDispatchSizeSelected != 0) {
1208+
} else if (walkerCmd.getThreadGroupIdYDimension() > 1 && walkerCmd.getThreadGroupIdZDimension() > 1) {
1209+
while (walkerCmd.getThreadGroupIdYDimension() % tgDispatchSizeSelected != 0) {
12061210
tgDispatchSizeSelected /= 2;
12071211
}
12081212
}

shared/source/xe_hpg_core/command_encoder_xe_hpg_core.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ void EncodeDispatchKernel<Family>::encodeThreadGroupDispatch(InterfaceDescriptor
3333
const uint32_t *threadGroupDimensions, const uint32_t threadGroupCount, const uint32_t grfCount, const uint32_t threadsPerThreadGroup, WalkerType &walkerCmd) {
3434
const auto &productHelper = device.getProductHelper();
3535
if (productHelper.isDisableOverdispatchAvailable(hwInfo)) {
36-
if (threadsPerThreadGroup == 1) {
36+
if (interfaceDescriptor.getNumberOfThreadsInGpgpuThreadGroup() == 1) {
3737
interfaceDescriptor.setThreadGroupDispatchSize(static_cast<INTERFACE_DESCRIPTOR_DATA::THREAD_GROUP_DISPATCH_SIZE>(2u));
3838
} else {
3939
interfaceDescriptor.setThreadGroupDispatchSize(static_cast<INTERFACE_DESCRIPTOR_DATA::THREAD_GROUP_DISPATCH_SIZE>(3u));

0 commit comments

Comments
 (0)