Skip to content

Commit 8467c29

Browse files
committed
Add check for output layout
1 parent 00ffa57 commit 8467c29

File tree

1 file changed

+1
-30
lines changed

1 file changed

+1
-30
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -34,35 +34,6 @@ using namespace mlir;
3434

3535
namespace {
3636

37-
bool isDistributable(ArrayRef<int64_t> sgLayout, ArrayRef<int64_t> sgData,
38-
ArrayRef<int64_t> wgShape) {
39-
// Check rank consistency
40-
if (sgLayout.size() != sgData.size() || sgLayout.size() != wgShape.size())
41-
return false;
42-
43-
for (size_t i = 0; i < sgLayout.size(); ++i) {
44-
int64_t subgroupCount = sgLayout[i];
45-
int64_t subgroupData = sgData[i];
46-
int64_t shape = wgShape[i];
47-
48-
// Each subgroup must have positive data size
49-
if (subgroupData <= 0 || subgroupCount <= 0 || shape <= 0)
50-
return false;
51-
52-
// Total data assigned to all subgroups in this dimension
53-
int64_t totalSubgroupData = subgroupCount * subgroupData;
54-
55-
// Subgroups must not collectively exceed the shape
56-
if (totalSubgroupData > shape)
57-
return false;
58-
59-
// Each subgroup's data must evenly divide the shape
60-
if (shape % subgroupData != 0)
61-
return false;
62-
}
63-
return true;
64-
}
65-
6637
static std::pair<SmallVector<int64_t>, int>
6738
getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
6839
int count = 1;
@@ -393,7 +364,7 @@ struct WgToSgVectorBroadcastOp
393364
else
394365
return failure();
395366

396-
if (!isDistributable(sgLayout, sgShape, wgShape))
367+
if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
397368
return failure();
398369

399370
// Check if the srcShape has unit dim in dimensions being broadcasted,

0 commit comments

Comments
 (0)