@@ -34,35 +34,6 @@ using namespace mlir;
3434
3535namespace {
3636
37- bool isDistributable (ArrayRef<int64_t > sgLayout, ArrayRef<int64_t > sgData,
38- ArrayRef<int64_t > wgShape) {
39- // Check rank consistency
40- if (sgLayout.size () != sgData.size () || sgLayout.size () != wgShape.size ())
41- return false ;
42-
43- for (size_t i = 0 ; i < sgLayout.size (); ++i) {
44- int64_t subgroupCount = sgLayout[i];
45- int64_t subgroupData = sgData[i];
46- int64_t shape = wgShape[i];
47-
48- // Each subgroup must have positive data size
49- if (subgroupData <= 0 || subgroupCount <= 0 || shape <= 0 )
50- return false ;
51-
52- // Total data assigned to all subgroups in this dimension
53- int64_t totalSubgroupData = subgroupCount * subgroupData;
54-
55- // Subgroups must not collectively exceed the shape
56- if (totalSubgroupData > shape)
57- return false ;
58-
59- // Each subgroup's data must evenly divide the shape
60- if (shape % subgroupData != 0 )
61- return false ;
62- }
63- return true ;
64- }
65-
6637static std::pair<SmallVector<int64_t >, int >
6738getSgShapeAndCount (ArrayRef<int64_t > shape, xegpu::LayoutAttr layout) {
6839 int count = 1 ;
@@ -393,7 +364,7 @@ struct WgToSgVectorBroadcastOp
393364 else
394365 return failure ();
395366
396- if (!isDistributable (sgLayout, sgShape, wgShape ))
367+ if (!xegpu::XeGPUDialect::isEvenlyDistributable (wgShape, layout ))
397368 return failure ();
398369
399370 // Check if the srcShape has unit dim in dimensions being broadcasted,
0 commit comments