Skip to content

Commit 7b5e8f1

Browse files
committed
partial working
1 parent ab448a3 commit 7b5e8f1

File tree

2 files changed

+21
-14
lines changed

2 files changed

+21
-14
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUInstructionlize.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,18 +120,22 @@ bool XeGPUInstructionlizePass::needsUnroll(Operation *op) const {
120120

121121
void XeGPUInstructionlizePass::runOnOperation() {
122122
MLIRContext *ctx = &getContext();
123+
Operation *op = getOperation();
124+
125+
// first perform type conversion for SCF control flow ops
126+
xegpu::doSCFStructuralTypeConversionWithTensorType(op);
127+
123128
xegpu::UnrollOptions options;
124129
options.setFilterConstraint([&](Operation *op) -> LogicalResult {
125130
return needsUnroll(op) ? success() : failure();
126131
});
127132

128-
options.setNativeShapeFn(
129-
[&](Operation *op) -> std::optional<SmallVector<int64_t>> {
133+
options.setNativeShapeFn([&](Operation *op) {
130134
return getTileShape(op);
131135
});
132136

133137
options.setUnrolledTypesFn(
134-
[&](ShapedType type, ArrayRef<int64_t> tileShape) -> SmallVector<Type> {
138+
[&](ShapedType type, ArrayRef<int64_t> tileShape) {
135139
Type elemTy = type.getElementType();
136140
Type newTy;
137141

@@ -149,8 +153,10 @@ void XeGPUInstructionlizePass::runOnOperation() {
149153
return SmallVector<Type>(computeProduct(*ratio), newTy);
150154
});
151155

152-
RewritePatternSet patterns(ctx);
156+
GreedyRewriteConfig config;
157+
config.setStrictness(GreedyRewriteStrictness::ExistingOps);
153158

159+
RewritePatternSet patterns(ctx);
154160
populateXeGPUUnrollPatterns(patterns, options);
155-
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
161+
(void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
156162
}

mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -215,17 +215,16 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
215215
// LayoutAttr
216216

217217
auto computeTileShapeAndCount = [&](ArrayRef<int64_t> shape,
218-
DenseI32ArrayAttr sgDataAttr,
219-
DenseI32ArrayAttr sgLayoutAttr) {
218+
DenseI32ArrayAttr sgDataAttr,
219+
DenseI32ArrayAttr sgLayoutAttr) {
220220
SmallVector<int64_t> tileShape;
221221
auto sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
222222
if (sgDataAttr)
223223
tileShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
224224
else
225225
tileShape = computeShapeRatio(shape, sgLayout).value_or(tileShape);
226226
assert(tileShape.size() && "failed to compute tileShape");
227-
SmallVector<int64_t> distUnit =
228-
computeElementwiseMul(sgLayout, tileShape);
227+
SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, tileShape);
229228
int count = computeProduct(shape) / computeProduct(distUnit);
230229
return std::make_pair(tileShape, count);
231230
};
@@ -249,8 +248,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
249248
if (layout.isWgLayout()) {
250249
// for WgToSg, the subShape is either from sgData or computed as
251250
// shape/sgLayout
252-
std::tie(subShape, count) = computeTileShapeAndCount(
253-
shape, layout.getSgData(), layout.getSgLayout());
251+
std::tie(subShape, count) = computeTileShapeAndCount(shape, layout.getSgData(), layout.getSgLayout());
254252
} else if (DenseI32ArrayAttr instData = layout.getInstData()) {
255253
// for unrolling, the subShape is determined by inst_data
256254
subShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
@@ -280,8 +278,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
280278
if (layout.isWgLayout()) {
281279
// for WgToSg, the subShape is either from sgData or computed as
282280
// shape/sgLayout
283-
std::tie(subShape, count) = computeTileShapeAndCount(
284-
shape, layout.getSgData(), layout.getSgLayout());
281+
std::tie(subShape, count) = computeTileShapeAndCount(shape, layout.getSgData(), layout.getSgLayout());
285282
layout = layout.dropSgLayoutAndData();
286283
} else if (DenseI32ArrayAttr instData = layout.getInstData()) {
287284
// for unrolling, the subShape is determined by inst_data
@@ -298,7 +295,11 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
298295
});
299296

300297
converter.addSourceMaterialization(materializeCast);
301-
converter.addTargetMaterialization(materializeCast);
298+
converter.addTargetMaterialization([&](OpBuilder &builder, TypeRange type,
299+
ValueRange inputs, Location loc) {
300+
return builder.create<UnrealizedConversionCastOp>(loc, type, inputs)
301+
.getResults();
302+
});
302303

303304
mlir::ConversionTarget target(*context);
304305
target.addLegalOp<UnrealizedConversionCastOp>();

0 commit comments

Comments
 (0)