@@ -34,6 +34,35 @@ using namespace mlir;
3434
3535namespace {
3636
37+ bool isDistributable (ArrayRef<int64_t > sgLayout, ArrayRef<int64_t > sgData,
38+ ArrayRef<int64_t > wgShape) {
39+ // Check rank consistency
40+ if (sgLayout.size () != sgData.size () || sgLayout.size () != wgShape.size ())
41+ return false ;
42+
43+ for (size_t i = 0 ; i < sgLayout.size (); ++i) {
44+ int64_t subgroupCount = sgLayout[i];
45+ int64_t subgroupData = sgData[i];
46+ int64_t shape = wgShape[i];
47+
48+ // Each subgroup must have positive data size
49+ if (subgroupData <= 0 || subgroupCount <= 0 || shape <= 0 )
50+ return false ;
51+
52+ // Total data assigned to all subgroups in this dimension
53+ int64_t totalSubgroupData = subgroupCount * subgroupData;
54+
55+ // Subgroups must not collectively exceed the shape
56+ if (totalSubgroupData > shape)
57+ return false ;
58+
59+ // Each subgroup's data must evenly divide the shape
60+ if (shape % subgroupData != 0 )
61+ return false ;
62+ }
63+ return true ;
64+ }
65+
3766static std::pair<SmallVector<int64_t >, int >
3867getSgShapeAndCount (ArrayRef<int64_t > shape, xegpu::LayoutAttr layout) {
3968 int count = 1 ;
@@ -357,6 +386,16 @@ struct WgToSgVectorBroadcastOp
357386 VectorType newResultType =
358387 VectorType::get (sgShape, resultType.getElementType ());
359388
389+ // Check if the output layout is distributable
390+ SmallVector<int64_t > sgLayout;
391+ if (auto sgLayoutAttr = layout.getSgLayout ())
392+ sgLayout = llvm::to_vector_of<int64_t >(sgLayoutAttr.asArrayRef ());
393+ else
394+ return failure ();
395+
396+ if (!isDistributable (sgLayout, sgShape, wgShape))
397+ return failure ();
398+
360399 // Check if the srcShape has unit dim in dimensions being broadcasted,
361400 // and the other dimensions are the same as the destination type
362401 // TODO: Generalize it
0 commit comments