//
//===----------------------------------------------------------------------===//

+ #include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
@@ -30,6 +31,61 @@ void XeGPUDialect::initialize() {
      >();
}

+ bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
+                                          xegpu::LayoutAttr attr) {
+   assert(attr && "Layout attribute is missing.");
+
+   auto getSubShapeOrNull =
+       [&](llvm::ArrayRef<int64_t> shape, DenseI32ArrayAttr layout,
+           DenseI32ArrayAttr data,
+           bool use_rr = true) -> std::optional<SmallVector<int64_t>> {
+     llvm::SmallVector<int64_t> newShape(shape);
+     if (layout) {
+       auto vec = llvm::to_vector_of<int64_t>(layout.asArrayRef());
+       if (vec.size() != shape.size())
+         return std::nullopt;
+       auto ratio = computeShapeRatio(shape, vec);
+       if (!ratio.has_value())
+         return std::nullopt;
+       newShape = ratio.value();
+     }
+
+     if (data) {
+       auto vec = llvm::to_vector_of<int64_t>(data.asArrayRef());
+       if (vec.size() != shape.size())
+         return std::nullopt;
+       auto ratio = computeShapeRatio(newShape, vec);
+       if (!ratio.has_value() && use_rr)
+         ratio = computeShapeRatio(vec, newShape);
+       if (!ratio.has_value())
+         return std::nullopt;
+
+       // If data is present, always return it as the shape for the next phase.
+       newShape = vec;
+     }
+     return newShape;
+   };
+
+   // Check sgLayout and sgData.
+   auto maybeSgShape =
+       getSubShapeOrNull(shape, attr.getSgLayout(), attr.getSgData());
+   if (!maybeSgShape)
+     return false;
+   auto sgShape = maybeSgShape.value();
+
+   // Check instData; it has neither a layout nor a round-robin fallback.
+   auto maybeInstShape =
+       getSubShapeOrNull(sgShape, nullptr, attr.getInstData(), false);
+   if (!maybeInstShape)
+     return false;
+   auto instShape = maybeInstShape.value();
+
+   // Check laneLayout and laneData.
+   auto maybeLaneShape = getSubShapeOrNull(instShape, attr.getLaneLayout(),
+                                           attr.getLaneData(), false);
+   return maybeLaneShape.has_value();
+ }
+
//===----------------------------------------------------------------------===//
// XeGPU_BlockTensorDescAttr
//===----------------------------------------------------------------------===//
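
To make the phased check in isEvenlyDistributable easier to follow, here is a small self-contained sketch of the same idea written against the standard library rather than the MLIR utilities. The shapes and layout factors are made-up values for illustration; the instData phase and the round-robin fallback are only noted in comments, so this is an approximation of the logic above, not the dialect implementation.

#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// Element-wise division that succeeds only when every dimension divides
// evenly, similar in spirit to mlir::computeShapeRatio.
static std::optional<std::vector<int64_t>>
shapeRatio(const std::vector<int64_t> &shape, const std::vector<int64_t> &div) {
  if (shape.size() != div.size())
    return std::nullopt;
  std::vector<int64_t> ratio;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (div[i] == 0 || shape[i] % div[i] != 0)
      return std::nullopt;
    ratio.push_back(shape[i] / div[i]);
  }
  return ratio;
}

int main() {
  // Hypothetical tensor shape and layout factors (not taken from any test).
  std::vector<int64_t> shape = {128, 128};
  std::vector<int64_t> sgLayout = {4, 8}, sgData = {32, 16};
  std::vector<int64_t> laneLayout = {1, 16}, laneData = {1, 1};

  // Phase 1: split across subgroups, then check each subgroup tile against
  // sgData. When sgData is present it becomes the shape for the next phase.
  // (The real lambda also tries the inverse ratio as a round-robin fallback.)
  auto perSg = shapeRatio(shape, sgLayout);                   // {32, 16}
  bool ok = perSg && shapeRatio(*perSg, sgData).has_value();  // {1, 1}
  std::vector<int64_t> sgShape = sgData;

  // The instData phase is skipped here. The lane phase splits the subgroup
  // tile across lanes and checks the per-lane tile against laneData.
  auto perLane = shapeRatio(sgShape, laneLayout);             // {32, 1}
  ok = ok && perLane && shapeRatio(*perLane, laneData).has_value();

  std::cout << (ok ? "evenly distributable\n" : "not distributable\n");
  return 0;
}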
@@ -241,7 +297,7 @@ LogicalResult TensorDescType::verify(
    llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
    mlir::Attribute encoding, mlir::Attribute layout) {
  size_t rank = shape.size();
-   // Low-pressure types are packed in 32-bit units.
+   // Low-precision types are packed in 32-bit units.
  int32_t packingFactor = 32 / elementType.getIntOrFloatBitWidth();
  if (rank != 1 && rank != 2)
    return emitError() << "expected 1D or 2D tensor";
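
As a quick aside on the reworded comment: packingFactor is simply 32 divided by the element bit width, i.e. how many narrow elements share one 32-bit unit. A minimal standalone illustration, assuming generic bit widths rather than specific XeGPU element types:

#include <cstdint>
#include <initializer_list>
#include <iostream>

int main() {
  // packingFactor = 32 / element bit width: the number of low-precision
  // elements that fit into one 32-bit unit.
  for (int32_t bits : {8, 16, 32}) {
    int32_t packingFactor = 32 / bits;
    std::cout << bits << "-bit elements -> packing factor " << packingFactor
              << "\n"; // prints 4, 2, 1
  }
  return 0;
}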
@@ -268,23 +324,21 @@ LogicalResult TensorDescType::verify(
    }
  }

-   if (auto blockAttr =
-           mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding)) {
+   auto blockAttr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
+   if (blockAttr) {
    MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
    if (rank == 2 && memorySpaceAttr &&
        memorySpaceAttr.getValue() == MemorySpace::SLM)
      return emitError() << "SLM is not supported for 2D block tensor";
  }

-   if (auto layoutAttr = llvm::dyn_cast_if_present<LayoutAttr>(layout)) {
-
+   auto layoutAttr = llvm::dyn_cast_if_present<LayoutAttr>(layout);
+   if (layoutAttr) {
    if (rank != (size_t)layoutAttr.getRank())
      return emitError() << "expected layout rank to match tensor rank";

-     ArrayRef<int32_t> laneLayout = layoutAttr.getLaneLayout().asArrayRef();
-     ArrayRef<int32_t> laneData = layoutAttr.getLaneData().asArrayRef();
-
-     if (scatterAttr) {
+     auto laneData = layoutAttr.getLaneData();
+     if (scatterAttr && laneData) {
      // Validate subgroup mapping rules for scattered tensors.
      // A work-item's slice of the tensor with shape [sg_size] or
      // [sg_size, chunk_size] will be [1] or [1, 32/element_ty_bit_width]
@@ -294,20 +348,19 @@ LogicalResult TensorDescType::verify(
      if (rank > 1 && laneData[0] != 1)
        return emitError()
               << "cannot map over non-contiguous scattered row elements";
-       if (laneData.back() != packingFactor)
+       if (laneData[rank - 1] != packingFactor)
        return emitError() << "work item data mapping must match the number of "
                              "contiguous elements";
    }

-     for (size_t i = 0; i < shape.size(); ++i) {
-       uint32_t numElemPerWi = laneLayout[i] * laneData[i];
-       if (shape[i] < numElemPerWi || shape[i] % numElemPerWi != 0)
-         return emitError() << "cannot distribute " << shape[i] << " over "
-                            << laneLayout[i] << " work items with "
-                            << laneData[i] << " elements each";
+     if (!XeGPUDialect::isEvenlyDistributable(shape, layoutAttr)) {
+       std::string shapeStr;
+       llvm::raw_string_ostream stream(shapeStr);
+       llvm::interleaveComma(shape, stream);
+       return emitError() << "cannot distribute [" << shapeStr << "] using "
+                          << layoutAttr;
    }
  }
-
  return success();
}

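For reference, the new diagnostic prints the whole tensor shape rather than a single offending dimension. A minimal sketch of the resulting message text, using std::ostringstream in place of llvm::raw_string_ostream/interleaveComma and a placeholder for the streamed layout attribute:

#include <cstdint>
#include <iostream>
#include <sstream>
#include <vector>

// Joins the shape with ", " the way llvm::interleaveComma would.
static std::string joinShape(const std::vector<int64_t> &shape) {
  std::ostringstream os;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (i)
      os << ", ";
    os << shape[i];
  }
  return os.str();
}

int main() {
  std::vector<int64_t> shape = {30, 64}; // hypothetical, not evenly divisible
  // "#xegpu.layout<...>" is a placeholder; the verifier streams the actual
  // LayoutAttr instead.
  std::cout << "cannot distribute [" << joinShape(shape) << "] using "
            << "#xegpu.layout<...>\n";
  return 0;
}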