@@ -31,11 +31,21 @@ void XeGPUDialect::initialize() {
3131 >();
3232}
3333
34+ // Checks if the given shape can be evenly distributed based on the layout
35+ // and data factors provided by the LayoutAttr.
3436bool XeGPUDialect::isEvenlyDistributable (llvm::ArrayRef<int64_t > shape,
3537 xegpu::LayoutAttr attr) {
3638 assert (attr && " Layout attribute is missing." );
3739
38- auto getSubShapeOrNull =
40+ // Checks whether the given shape can be evenly distributed using the specified
41+ // layout and data attributes. If successful, it returns the work size for each
42+ // compute unit; otherwise, it returns `std::nullopt`. The work size per compute
43+ // unit is calculated as follows:
44+ // - If `data` is null: newShape[i] = shape[i] / layout[i]
45+ // - If `data` is not null: newShape[i] = data[i]
46+ // When round-robin distribution (`use_rr`) is enabled, `shape[i]` can be smaller
47+ // than `layout[i] * data[i]`, allowing multiple compute units to share the data.
48+ auto tryDistribute =
3949 [&](llvm::ArrayRef<int64_t > shape, DenseI32ArrayAttr layout,
4050 DenseI32ArrayAttr data,
4151 bool use_rr = true ) -> std::optional<SmallVector<int64_t >> {
@@ -68,20 +78,20 @@ bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
6878
6979 // check the sgLayout and sgData
7080 auto maybeSgShape =
71- getSubShapeOrNull (shape, attr.getSgLayout (), attr.getSgData ());
81+ tryDistribute (shape, attr.getSgLayout (), attr.getSgData ());
7282 if (!maybeSgShape)
7383 return false ;
7484 auto sgShape = maybeSgShape.value ();
7585
7686 // check InstData, it neither have layout nor need round-robin
7787 auto maybeInstShape =
78- getSubShapeOrNull (sgShape, nullptr , attr.getInstData (), false );
88+ tryDistribute (sgShape, nullptr , attr.getInstData (), false );
7989 if (!maybeInstShape)
8090 return false ;
8191 auto instShape = maybeInstShape.value ();
8292
8393 // check LaneLayout and LaneData
84- auto maybeLaneShape = getSubShapeOrNull (instShape, attr.getLaneLayout (),
94+ auto maybeLaneShape = tryDistribute (instShape, attr.getLaneLayout (),
8595 attr.getLaneData (), false );
8696 return maybeLaneShape.has_value ();
8797}
0 commit comments