Skip to content

Commit 00ffa57

Browse files
committed
Temp commit to check isDiscardable
1 parent 425d677 commit 00ffa57

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,35 @@ using namespace mlir;
3434

3535
namespace {
3636

37+
bool isDistributable(ArrayRef<int64_t> sgLayout, ArrayRef<int64_t> sgData,
38+
ArrayRef<int64_t> wgShape) {
39+
// Check rank consistency
40+
if (sgLayout.size() != sgData.size() || sgLayout.size() != wgShape.size())
41+
return false;
42+
43+
for (size_t i = 0; i < sgLayout.size(); ++i) {
44+
int64_t subgroupCount = sgLayout[i];
45+
int64_t subgroupData = sgData[i];
46+
int64_t shape = wgShape[i];
47+
48+
// Each subgroup must have positive data size
49+
if (subgroupData <= 0 || subgroupCount <= 0 || shape <= 0)
50+
return false;
51+
52+
// Total data assigned to all subgroups in this dimension
53+
int64_t totalSubgroupData = subgroupCount * subgroupData;
54+
55+
// Subgroups must not collectively exceed the shape
56+
if (totalSubgroupData > shape)
57+
return false;
58+
59+
// Each subgroup's data must evenly divide the shape
60+
if (shape % subgroupData != 0)
61+
return false;
62+
}
63+
return true;
64+
}
65+
3766
static std::pair<SmallVector<int64_t>, int>
3867
getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
3968
int count = 1;
@@ -357,6 +386,16 @@ struct WgToSgVectorBroadcastOp
357386
VectorType newResultType =
358387
VectorType::get(sgShape, resultType.getElementType());
359388

389+
// Check if the output layout is distributable
390+
SmallVector<int64_t> sgLayout;
391+
if (auto sgLayoutAttr = layout.getSgLayout())
392+
sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
393+
else
394+
return failure();
395+
396+
if (!isDistributable(sgLayout, sgShape, wgShape))
397+
return failure();
398+
360399
// Check if the srcShape has unit dim in dimensions being broadcasted,
361400
// and the other dimensions are the same as the destination type
362401
// TODO: Generalize it

0 commit comments

Comments
 (0)