Skip to content

Commit 727390f

Browse files
committed
add comments
1 parent 15b1b46 commit 727390f

File tree

2 files changed

+34
-12
lines changed

2 files changed

+34
-12
lines changed

mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,23 @@ void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
4343
/// Appends patterns for XeGPU SIMT distribution into `patterns`.
4444
void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns);
4545

46+
/// Collect a set of pattern to unroll xegpu operations to a smaller shapes.
47+
/// Users can control whether an operation to be unrolled or not, as well as
48+
/// the its target shape via `options` structure. (via setting filterConstraint
49+
/// and nativeShape respectively, both of them are function refs taking `op` as
50+
/// the input).
51+
/// An `op` is unrolled to the `targetShape` as follows, for each of its
52+
/// operands:
53+
/// 1. the unrolled type `unrolledType` and number of unrolled instances
54+
/// `numUnrolledInstances` are computed from the `targetShape`.
55+
/// 2. ExtractStridedSlice are created to break-up the vector operands. And
56+
/// BuildinUnrealizedCastop are created to break-up the TensorDesc operands.
57+
/// 3. the original op is cloned `numUnrolledInstances` times, once for each
58+
/// result.
59+
/// 4. InsertStridedSlice are inserted for VectorType result, and
60+
/// BuildinUnrealizedCastOp are inserted for TensorDescType result to
61+
/// re-assemble the slices into the original shape.
62+
///
4663
void populateXeGPUUnrollPatterns(RewritePatternSet &patterns,
4764
const UnrollOptions &options);
4865

mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
1919
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
2020
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21+
#include "llvm/ADT/STLExtras.h"
2122
#include "llvm/Support/Debug.h"
2223
#include <numeric>
2324

@@ -79,8 +80,8 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
7980
return layout.dropInstData();
8081
};
8182

82-
SmallVector<Type> convertType(ShapedType type,
83-
ArrayRef<int64_t> blockSize) const {
83+
SmallVector<Type> getUnrolledTypes(ShapedType type,
84+
ArrayRef<int64_t> blockSize) const {
8485
auto elemTy = type.getElementType();
8586
Type newTy;
8687
// TensorDescType needs to drop the inst_data field in the layout attribute
@@ -99,8 +100,8 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
99100
return llvm::SmallVector<Type>(computeProduct(*ratio), newTy);
100101
}
101102

102-
// emulate the the unpack behavior using insert_strided_slice for VectorType
103-
// values and unrealized_conversion_cast for TileType values.
103+
/// emulate the the unpack behavior using insert_strided_slice for VectorType
104+
/// values and unrealized_conversion_cast for TileType values.
104105
Value unpack(ValueRange srcs, Type destTy, llvm::ArrayRef<int64_t> blockSize,
105106
Location loc, PatternRewriter &rewriter) const {
106107
if (auto vecTy = dyn_cast<VectorType>(destTy)) {
@@ -136,8 +137,8 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
136137
return Value();
137138
}
138139

139-
// emulate the the pack behavior using extract_strided_slice for VectorType
140-
// values and unrealized_conversion_cast for TensorDescType values.
140+
/// emulate the the pack behavior using extract_strided_slice for VectorType
141+
/// values and unrealized_conversion_cast for TensorDescType values.
141142
llvm::SmallVector<Value> pack(Value src, TypeRange destTypes,
142143
llvm::ArrayRef<int64_t> blockSize, Location loc,
143144
PatternRewriter &rewriter) const {
@@ -266,7 +267,7 @@ struct UnrollUpdateNdOffsetOp : public UnrollPattern<xegpu::UpdateNdOffsetOp> {
266267
return failure();
267268
auto grids = *maybeGrids;
268269

269-
auto convertedTdescTypes = convertType(tdescTy, targetShape);
270+
auto convertedTdescTypes = getUnrolledTypes(tdescTy, targetShape);
270271
auto convertedTdesc =
271272
pack(tdesc, convertedTdescTypes, targetShape, loc, rewriter);
272273

@@ -301,7 +302,7 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
301302
return failure();
302303
auto grids = *maybeGrids;
303304

304-
auto convertedTdescTypes = convertType(tdescTy, targetShape);
305+
auto convertedTdescTypes = getUnrolledTypes(tdescTy, targetShape);
305306
auto convertedTdesc =
306307
pack(tdesc, convertedTdescTypes, targetShape, loc, rewriter);
307308

@@ -340,7 +341,7 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
340341
auto elemTy = tdescTy.getElementType();
341342
auto newValueTy = valueTy.cloneWith(targetShape, elemTy);
342343

343-
auto convertedTdescTypes = convertType(tdescTy, targetShape);
344+
auto convertedTdescTypes = getUnrolledTypes(tdescTy, targetShape);
344345
auto convertedTdescs = pack(op.getTensorDesc(), convertedTdescTypes,
345346
targetShape, loc, rewriter);
346347

@@ -380,8 +381,8 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
380381
return failure();
381382
auto grids = *maybeGrids;
382383

383-
auto convertedValTypes = convertType(valueTy, targetShape);
384-
auto convertedTdescTypes = convertType(tdescTy, targetShape);
384+
auto convertedValTypes = getUnrolledTypes(valueTy, targetShape);
385+
auto convertedTdescTypes = getUnrolledTypes(tdescTy, targetShape);
385386

386387
auto convertedValues =
387388
pack(op.getValue(), convertedValTypes, targetShape, loc, rewriter);
@@ -448,8 +449,12 @@ struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
448449

449450
// skip the operation if every operand has an invalid blocking size (empty)
450451
// or if the original shape matches the blocking size (size == 1).
451-
if (aVals.size() <= 1 && bVals.size() <= 1 && cVals.size() <= 1)
452+
auto ranges = c ? SmallVector<ValueRange>({aVals, bVals, cVals})
453+
: SmallVector<ValueRange>({aVals, bVals});
454+
if (any_of(ranges, [](auto &v) { return v.size() == 0; }) ||
455+
all_of(ranges, [](auto &v) { return v.size() == 1; })) {
452456
return failure();
457+
}
453458

454459
auto resultTy = op.getResult().getType();
455460
auto vecTy = VectorType::get(cBlockSize, resultTy.getElementType());

0 commit comments

Comments
 (0)