Skip to content

Commit a028590

Browse files
authored
Update PropagateLayout pass (#837)
- added support for ExtractOp - rename it to apply-vnni-transform
1 parent c9a4ecc commit a028590

File tree

11 files changed

+1350
-24
lines changed

11 files changed

+1350
-24
lines changed

include/imex/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ std::unique_ptr<mlir::Pass> createRemoveTemporariesPass();
3131
std::unique_ptr<mlir::Pass> createVectorLinearizePass();
3232
std::unique_ptr<mlir::Pass> createPropagatePackedLayoutPass();
3333
std::unique_ptr<mlir::Pass> createRemoveSingleElemVectorPass();
34+
std::unique_ptr<mlir::Pass> createVnniTransformationPass();
3435

3536
#define GEN_PASS_DECL
3637
#include "imex/Transforms/Passes.h.inc"

include/imex/Transforms/Passes.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,5 +146,14 @@ def RemoveSingleElemVector : Pass<"imex-remove-single-elem-vector"> {
146146
];
147147
}
148148

149+
def VnniTransformation : Pass<"imex-xegpu-apply-vnni-transformation"> {
150+
let summary = "apply vnni transformation for to B operand of dpas instructions if necessary.";
151+
let constructor = "imex::createVnniTransformationPass()";
152+
153+
let dependentDialects = [
154+
"::mlir::vector::VectorDialect"
155+
];
156+
}
157+
149158

150159
#endif // _IMEX_TRANSFORMS_PASSES_TD_INCLUDED_

lib/Conversion/XeGPUToVC/XeGPUToVC.cpp

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -140,28 +140,27 @@ static Value adjustBasePointer(ConversionPatternRewriter &rewriter,
140140
auto strides = mlir::getStridesAndOffset(memType).first;
141141
int64_t i = memType.getRank() - 1;
142142

143-
auto computeBase =
144-
[&](Value base) {
145-
for (; i >= 0; --i) {
146-
unsigned stride =
147-
strides[i] * memType.getElementType().getIntOrFloatBitWidth() / 8;
148-
auto factor = createIndexConstant(stride);
149-
auto offset = offsets.pop_back_val();
150-
Value offsetVal;
151-
152-
if (offset.is<Value>()) {
153-
offsetVal = offset.get<Value>();
154-
} else {
155-
offsetVal = createIndexConstant(
156-
llvm::cast<IntegerAttr>(offset.get<Attribute>()).getInt());
157-
}
158-
auto linearOffset =
159-
rewriter.create<arith::MulIOp>(loc, offsetVal, factor);
160-
base = rewriter.create<arith::AddIOp>(loc, base, linearOffset);
161-
}
143+
auto computeBase = [&](Value base) {
144+
for (; i >= 0; --i) {
145+
unsigned stride =
146+
strides[i] * memType.getElementType().getIntOrFloatBitWidth() / 8;
147+
auto factor = createIndexConstant(stride);
148+
auto offset = offsets.pop_back_val();
149+
Value offsetVal;
150+
151+
if (offset.is<Value>()) {
152+
offsetVal = offset.get<Value>();
153+
} else {
154+
offsetVal = createIndexConstant(
155+
llvm::cast<IntegerAttr>(offset.get<Attribute>()).getInt());
156+
}
157+
auto linearOffset =
158+
rewriter.create<arith::MulIOp>(loc, offsetVal, factor);
159+
base = rewriter.create<arith::AddIOp>(loc, base, linearOffset);
160+
}
162161

163-
return base;
164-
};
162+
return base;
163+
};
165164

166165
if (tileRank == 2 && memType.getRank() > 2) {
167166
// base address of plane for 2d: base addr of memref + offsets (starting

lib/Dialect/XeTile/IR/XeTileDialect.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,10 @@ mlir::LogicalResult XeTileAttr::verify(
120120

121121
if (order != mlir::DenseI32ArrayAttr() && order.size() != 2)
122122
emitError() << "expect integer array of size 2 for order";
123-
if (inner_blocks != mlir::DenseI64ArrayAttr() && (inner_blocks.size() > 0 && inner_blocks.size() != 2))
124-
emitError() << "expect integer array of size 2 for non empty inner_blocks attribute";
123+
if (inner_blocks != mlir::DenseI64ArrayAttr() &&
124+
(inner_blocks.size() > 0 && inner_blocks.size() != 2))
125+
emitError() << "expect integer array of size 2 for non empty inner_blocks "
126+
"attribute";
125127
return mlir::success();
126128
}
127129

lib/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ add_mlir_library(IMEXTransforms
1010
SetSPIRVAbiAttribute.cpp
1111
SetSPIRVCapabilities.cpp
1212
VectorLinearize.cpp
13+
VnniTransformation.cpp
1314

1415
ADDITIONAL_HEADER_DIRS
1516
${PROJECT_SOURCE_DIR}/imex/Transforms

0 commit comments

Comments
 (0)