1818#include " mlir/Dialect/XeGPU/IR/XeGPU.h"
1919#include " mlir/Dialect/XeGPU/Transforms/Transforms.h"
2020#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
21+ #include " llvm/ADT/STLExtras.h"
2122#include " llvm/Support/Debug.h"
2223#include < numeric>
2324
@@ -79,8 +80,8 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
7980 return layout.dropInstData ();
8081 };
8182
82- SmallVector<Type> convertType (ShapedType type,
83- ArrayRef<int64_t > blockSize) const {
83+ SmallVector<Type> getUnrolledTypes (ShapedType type,
84+ ArrayRef<int64_t > blockSize) const {
8485 auto elemTy = type.getElementType ();
8586 Type newTy;
8687 // TensorDescType needs to drop the inst_data field in the layout attribute
@@ -99,8 +100,8 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
99100 return llvm::SmallVector<Type>(computeProduct (*ratio), newTy);
100101 }
101102
102- // emulate the the unpack behavior using insert_strided_slice for VectorType
103- // values and unrealized_conversion_cast for TileType values.
103+ // / emulate the unpack behavior using insert_strided_slice for VectorType
104+ // / values and unrealized_conversion_cast for TileType values.
104105 Value unpack (ValueRange srcs, Type destTy, llvm::ArrayRef<int64_t > blockSize,
105106 Location loc, PatternRewriter &rewriter) const {
106107 if (auto vecTy = dyn_cast<VectorType>(destTy)) {
@@ -136,8 +137,8 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
136137 return Value ();
137138 }
138139
139- // emulate the the pack behavior using extract_strided_slice for VectorType
140- // values and unrealized_conversion_cast for TensorDescType values.
140+ // / emulate the pack behavior using extract_strided_slice for VectorType
141+ // / values and unrealized_conversion_cast for TensorDescType values.
141142 llvm::SmallVector<Value> pack (Value src, TypeRange destTypes,
142143 llvm::ArrayRef<int64_t > blockSize, Location loc,
143144 PatternRewriter &rewriter) const {
@@ -266,7 +267,7 @@ struct UnrollUpdateNdOffsetOp : public UnrollPattern<xegpu::UpdateNdOffsetOp> {
266267 return failure ();
267268 auto grids = *maybeGrids;
268269
269- auto convertedTdescTypes = convertType (tdescTy, targetShape);
270+ auto convertedTdescTypes = getUnrolledTypes (tdescTy, targetShape);
270271 auto convertedTdesc =
271272 pack (tdesc, convertedTdescTypes, targetShape, loc, rewriter);
272273
@@ -301,7 +302,7 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
301302 return failure ();
302303 auto grids = *maybeGrids;
303304
304- auto convertedTdescTypes = convertType (tdescTy, targetShape);
305+ auto convertedTdescTypes = getUnrolledTypes (tdescTy, targetShape);
305306 auto convertedTdesc =
306307 pack (tdesc, convertedTdescTypes, targetShape, loc, rewriter);
307308
@@ -340,7 +341,7 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
340341 auto elemTy = tdescTy.getElementType ();
341342 auto newValueTy = valueTy.cloneWith (targetShape, elemTy);
342343
343- auto convertedTdescTypes = convertType (tdescTy, targetShape);
344+ auto convertedTdescTypes = getUnrolledTypes (tdescTy, targetShape);
344345 auto convertedTdescs = pack (op.getTensorDesc (), convertedTdescTypes,
345346 targetShape, loc, rewriter);
346347
@@ -380,8 +381,8 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
380381 return failure ();
381382 auto grids = *maybeGrids;
382383
383- auto convertedValTypes = convertType (valueTy, targetShape);
384- auto convertedTdescTypes = convertType (tdescTy, targetShape);
384+ auto convertedValTypes = getUnrolledTypes (valueTy, targetShape);
385+ auto convertedTdescTypes = getUnrolledTypes (tdescTy, targetShape);
385386
386387 auto convertedValues =
387388 pack (op.getValue (), convertedValTypes, targetShape, loc, rewriter);
@@ -448,8 +449,12 @@ struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
448449
449450 // skip the operation if any operand has an invalid blocking size (empty)
450451 // or if every operand's original shape matches the blocking size (size == 1).
451- if (aVals.size () <= 1 && bVals.size () <= 1 && cVals.size () <= 1 )
452+ auto ranges = c ? SmallVector<ValueRange>({aVals, bVals, cVals})
453+ : SmallVector<ValueRange>({aVals, bVals});
454+ if (any_of (ranges, [](auto &v) { return v.size () == 0 ; }) ||
455+ all_of (ranges, [](auto &v) { return v.size () == 1 ; })) {
452456 return failure ();
457+ }
453458
454459 auto resultTy = op.getResult ().getType ();
455460 auto vecTy = VectorType::get (cBlockSize, resultTy.getElementType ());
0 commit comments