Skip to content

Commit 6fef430

Browse files
committed
refine
1 parent 35b35f0 commit 6fef430

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -95,19 +95,19 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
9595

9696
// emulate the the unpack behavior using insert_strided_slice for VectorType
9797
// values and unrealized_conversion_cast for TileType values.
98-
Value unpack(ValueRange srcs, Type destTy, llvm::ArrayRef<int64_t> innerBlock,
98+
Value unpack(ValueRange srcs, Type destTy, llvm::ArrayRef<int64_t> blockSize,
9999
Location loc, PatternRewriter &rewriter) const {
100100
if (auto vecTy = dyn_cast<VectorType>(destTy)) {
101-
assert(vecTy.getRank() == 2 && innerBlock.size() == 2 &&
102-
"Expecting innerBlock size to match the rank of destTy.");
101+
assert(vecTy.getRank() == 2 && blockSize.size() == 2 &&
102+
"Expecting blockSize size to match the rank of destTy.");
103103
auto shape = vecTy.getShape();
104104
auto zeroAttr = rewriter.getZeroAttr(vecTy.getElementType());
105105

106106
Value result = rewriter.create<arith::ConstantOp>(
107107
loc, vecTy, DenseElementsAttr::get(vecTy, zeroAttr));
108108
int64_t idx = 0;
109-
for (int64_t i = 0; i < shape[0]; i += innerBlock[0]) {
110-
for (int64_t j = 0; j < shape[1]; j += innerBlock[1]) {
109+
for (int64_t i = 0; i < shape[0]; i += blockSize[0]) {
110+
for (int64_t j = 0; j < shape[1]; j += blockSize[1]) {
111111
result = rewriter.create<vector::InsertStridedSliceOp>(
112112
loc, srcs[idx++], result, llvm::ArrayRef<int64_t>({i, j}),
113113
llvm::ArrayRef<int64_t>({1, 1}));
@@ -120,7 +120,7 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
120120
rewriter.getUnitAttr());
121121
auto innerBlkAttr =
122122
NamedAttribute(rewriter.getStringAttr(blockAttrName),
123-
rewriter.getDenseI64ArrayAttr(innerBlock));
123+
rewriter.getDenseI64ArrayAttr(blockSize));
124124
auto castOp = rewriter.create<UnrealizedConversionCastOp>(
125125
loc, destTy, srcs,
126126
llvm::ArrayRef<NamedAttribute>({attr, innerBlkAttr}));
@@ -134,17 +134,17 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
134134
// emulate the the pack behavior using extract_strided_slice for VectorType
135135
// values and unrealized_conversion_cast for TensorDescType values.
136136
llvm::SmallVector<Value> pack(Value src, TypeRange destTypes,
137-
llvm::ArrayRef<int64_t> innerBlock,
138-
Location loc, PatternRewriter &rewriter) const {
137+
llvm::ArrayRef<int64_t> blockSize, Location loc,
138+
PatternRewriter &rewriter) const {
139139
if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
140-
assert(vecTy.getRank() == 2 && innerBlock.size() == 2 &&
141-
"Expecting innerBlock size to match the rank of src.");
140+
assert(vecTy.getRank() == 2 && blockSize.size() == 2 &&
141+
"Expecting blockSize size to match the rank of src.");
142142
auto shape = vecTy.getShape();
143143
llvm::SmallVector<Value> results;
144-
for (int64_t i = 0; i < shape[0]; i += innerBlock[0]) {
145-
for (int64_t j = 0; j < shape[1]; j += innerBlock[1]) {
144+
for (int64_t i = 0; i < shape[0]; i += blockSize[0]) {
145+
for (int64_t j = 0; j < shape[1]; j += blockSize[1]) {
146146
auto slice = rewriter.create<vector::ExtractStridedSliceOp>(
147-
loc, src, llvm::ArrayRef<int64_t>({i, j}), innerBlock,
147+
loc, src, llvm::ArrayRef<int64_t>({i, j}), blockSize,
148148
llvm::ArrayRef<int64_t>({1, 1}));
149149
results.push_back(slice);
150150
}
@@ -155,7 +155,7 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
155155
rewriter.getUnitAttr());
156156
auto innerBlkAttr =
157157
NamedAttribute(rewriter.getStringAttr(blockAttrName),
158-
rewriter.getDenseI64ArrayAttr(innerBlock));
158+
rewriter.getDenseI64ArrayAttr(blockSize));
159159
auto castOp = rewriter.create<UnrealizedConversionCastOp>(
160160
loc, destTypes, src,
161161
llvm::ArrayRef<NamedAttribute>({attr, innerBlkAttr}));

0 commit comments

Comments
 (0)