Skip to content

Commit 227f725

Browse files
authored
[WgToSg] Generalize the handling of Structure Control Flow Op (#1073)
1 parent a95a408 commit 227f725

File tree

2 files changed

+180
-81
lines changed

2 files changed

+180
-81
lines changed

lib/Dialect/XeTile/Transforms/WgToSg.cpp

Lines changed: 129 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,20 @@ namespace imex {
4747

4848
namespace imex {
4949

50+
// Returns the work-group map attribute associated with `val`, taken from the
// op that defines it:
//   * for a xetile::LoadTileOp, the map stored on the loaded tile's type;
//   * otherwise the op's "map" attribute if present, else its "wg_map_c"
//     attribute if present.
// Returns a null attribute when `val` has no defining op, when neither
// attribute is present, or when the attribute found is not a
// xetile::WorkGroupMapAttr.
static xetile::WorkGroupMapAttr getWorkGroupMapAttr(Value val) {
  Operation *producer = val.getDefiningOp();
  if (!producer)
    return nullptr;

  if (auto loadOp = dyn_cast<xetile::LoadTileOp>(producer))
    return loadOp.getTile().getType().getWgMap();

  // Attribute names are probed in priority order; the first one that exists
  // on the op decides the result, even if it has an unexpected type.
  for (const char *attrName : {"map", "wg_map_c"}) {
    if (producer->hasAttr(attrName))
      return producer->getAttrOfType<xetile::WorkGroupMapAttr>(attrName);
  }
  return nullptr;
}
63+
5064
// This pass transforms the Ops at WG level to SG level using the
5165
// decomposition attributes provided by wg_map.
5266
// clang-format off
@@ -573,15 +587,6 @@ class WGToSGXeTileConvertLayout
573587
return rewriter.create<arith::ConstantIndexOp>(loc, value);
574588
};
575589

576-
// get the workgroup map attribute for a value from its defining op.
577-
auto getWorkGroupMapAttr = [&](Value val) {
578-
auto defOp = val.getDefiningOp();
579-
if (auto ld = dyn_cast<xetile::LoadTileOp>(defOp)) {
580-
return ld.getTile().getType().getWgMap();
581-
}
582-
return defOp->getAttrOfType<xetile::WorkGroupMapAttr>("map");
583-
};
584-
585590
auto isOneUseTranspose = [&](Operation *op) {
586591
return isa<xetile::TransposeOp, vector::TransposeOp>(op) && op->hasOneUse();
587592
};
@@ -1089,6 +1094,17 @@ class WGToSGMathFPowIOpPattern : public OpConversionPattern<mlir::math::FPowIOp>
10891094
}
10901095
};
10911096

1097+
// Conversion pattern that removes an UnrealizedConversionCastOp by replacing
// it with the operands handed to the pattern through the adaptor.
class UnrealizedConversionCastOpPattern
    : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
  using OpConversionPattern<mlir::UnrealizedConversionCastOp>::OpConversionPattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::UnrealizedConversionCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Users of the cast are rewired to the adaptor-provided values.
    auto forwarded = adaptor.getOperands();
    rewriter.replaceOp(castOp, forwarded);
    return mlir::success();
  }
};
1107+
10921108
// This function traverses backwards through loop-carried dependencies in SCF
10931109
// `for` loops to find the original (pre-loop) value.
10941110
static Value getPreLoopValue(Value val) {
@@ -1185,7 +1201,7 @@ void populateXeTileWgToSgPatterns(mlir::RewritePatternSet &patterns,
11851201
WGToSGArithSelectOpPattern, WGToSGMathFPowIOpPattern,
11861202
WGToSGVectorShapeCast, WGToSGVectorMultiDimReductionOp,
11871203
WGToSGLoadGatherOpPattern, WGToSGStoreScatterOpPattern,
1188-
WGToSGVectorCreateMask>(patterns.getContext());
1204+
WGToSGVectorCreateMask, UnrealizedConversionCastOpPattern>(patterns.getContext());
11891205
patterns.insert<WGToSGElementWiseOpSameArgAndResultTypePattern<math::ExpOp, 1>,
11901206
WGToSGElementWiseOpSameArgAndResultTypePattern<math::SqrtOp, 1>,
11911207
WGToSGElementWiseOpSameArgAndResultTypePattern<arith::AddFOp, 2>,
@@ -1232,7 +1248,7 @@ class XeTileWgToSgPass
12321248

12331249
void runOnOperation() override {
12341250
MLIRContext &context = getContext();
1235-
auto mod = this->getOperation();
1251+
auto mod = getOperation();
12361252

12371253
// skip functions with XeTile.TileType inputs and outputs
12381254
if (!isSupportedModule(mod)) {
@@ -1241,9 +1257,86 @@ class XeTileWgToSgPass
12411257
return signalPassFailure();
12421258
}
12431259

1244-
Operation *op = getOperation();
12451260
// Run the analysis to find the candidates for the transformation
1246-
analyzeTransposeOps(op, sgLayoutMap);
1261+
analyzeTransposeOps(mod, sgLayoutMap);
1262+
1263+
{ // Temporarily change VectorType to RankedTensorType for the operands,
  // block arguments and results of Structured Control Flow (SCF) ops. The
  // tensor encoding is then used to carry the work-group map of the value
  // that feeds each SCF op.
  mlir::TypeConverter converter;
  converter.addConversion([&](Type type) -> Type { return type; });
  converter.addConversion([&](VectorType type) -> Type {
    return RankedTensorType::get(type.getShape(), type.getElementType());
  });

  // Both source and target materializations bridge the two types with a
  // single-input UnrealizedConversionCastOp.
  auto materializeCast = [&](mlir::OpBuilder &builder, mlir::Type type,
                             mlir::ValueRange inputs,
                             mlir::Location loc) -> mlir::Value {
    if (inputs.size() != 1)
      return nullptr;
    return builder.create<UnrealizedConversionCastOp>(loc, type, inputs).getResult(0);
  };
  converter.addSourceMaterialization(materializeCast);
  converter.addTargetMaterialization(materializeCast);

  mlir::ConversionTarget target(context);
  target.addLegalOp<UnrealizedConversionCastOp>();

  mlir::RewritePatternSet patterns(&context);
  scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns, target);
  // NOTE(review): the conversion result is discarded, as before; confirm that
  // continuing after a failed structural conversion is intended.
  (void)mlir::applyPartialConversion(mod, target, std::move(patterns));

  // Propagate the layout info into the RankedTensorType result of cast ops.
  mod->walk([&](UnrealizedConversionCastOp castOp) {
    if (castOp.getNumOperands() != 1 || castOp.getNumResults() != 1)
      return WalkResult::skip();

    auto input = castOp.getInputs()[0];
    auto result = castOp.getResults()[0];
    auto inputTy = dyn_cast<VectorType>(input.getType());
    auto resultTy = dyn_cast<RankedTensorType>(result.getType());

    // Only look at ops casting from VectorType to RankedTensorType.
    // dyn_cast returns a null type on mismatch, so test for null directly;
    // isa<> must not be applied to a null type.
    if (!inputTy || !resultTy)
      return WalkResult::skip();

    auto wgMap = getWorkGroupMapAttr(input);
    if (!wgMap)
      return WalkResult::skip();

    auto newTy = resultTy.cloneWithEncoding(wgMap);
    result.setType(newTy);

    // Update the region arguments if the user is a LoopLike op.
    for (OpOperand &use : result.getUses()) {
      Operation *user = use.getOwner();
      if (auto loop = dyn_cast<LoopLikeOpInterface>(user)) {
        // The tied argument is null when the operand is not a loop init.
        if (auto arg = loop.getTiedLoopRegionIterArg(&use))
          arg.setType(newTy);
      }
      // whileOp has two regions; the BlockArgument of the after region
      // is not exposed by LoopLikeOpInterface.
      if (auto whileOp = dyn_cast<scf::WhileOp>(user)) {
        auto afterArgs = whileOp.getAfterArguments();
        auto idx = use.getOperandNumber();
        // The after region need not mirror the init operands, so only retype
        // an argument that exists at this position and still has the
        // un-encoded tensor type produced for this value.
        if (idx < afterArgs.size() && afterArgs[idx].getType() == resultTy)
          afterArgs[idx].setType(newTy);
      }
    }
    return WalkResult::advance();
  });

  // Use yieldOp as an anchor to update the result types of its parent op.
  mod->walk([&](scf::YieldOp yieldOp) {
    auto parentOp = yieldOp->getParentOp();
    auto yielded = yieldOp.getResults();
    for (auto r : parentOp->getOpResults()) {
      auto idx = r.getResultNumber();
      // The number of yielded values is not guaranteed to match the number
      // of parent results; stop before indexing past the yielded values.
      if (idx >= yielded.size())
        break;
      auto resultTy = r.getType();
      auto yieldTy = yielded[idx].getType();
      if (isa<RankedTensorType>(resultTy) && yieldTy != resultTy)
        r.setType(yieldTy);
    }
  });
}
1339+
12471340
mlir::ConversionTarget target(context);
12481341
mlir::RewritePatternSet patterns(&context);
12491342

@@ -1343,7 +1436,9 @@ class XeTileWgToSgPass
13431436
mlir::TypeConverter converter;
13441437
target.addIllegalOp<xetile::ConvertLayoutOp>();
13451438

1346-
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
1439+
target.markUnknownOpDynamicallyLegal([](Operation *op) {
1440+
return !isa<mlir::UnrealizedConversionCastOp>(op);
1441+
});
13471442

13481443
populateXeTileWgToSgPatterns(patterns, sgLayoutMap);
13491444

@@ -1354,71 +1449,25 @@ class XeTileWgToSgPass
13541449
// handle the conversion of the vector/tile type of same shape
13551450
// mapped to different sgData for region ops.
13561451
// TODO : Fix the type converter to handle such case.
1357-
converter.addConversion([op](Type type) -> Type {
1358-
Type resultType = type;
1359-
auto vecType = dyn_cast<VectorType>(type);
1360-
auto tileTy = dyn_cast<xetile::TileType>(type);
1361-
if (!vecType && !tileTy) return resultType;
1362-
1363-
op->walk([&](Operation *op) {
1364-
auto isResultType = [&](Value value, Type valueType) {
1365-
if (valueType != type) return false;
1366-
if (tileTy) {
1367-
if (!tileTy.getWgMap()) return false;
1368-
auto newShape = tileTy.getWgMap().getSgData();
1369-
resultType = xetile::TileType::get({newShape[0], newShape[1]}, tileTy.getElementType());
1370-
return true;
1371-
}
1372-
if (vecType) {
1373-
Operation *defOp = value.getDefiningOp();
1374-
if (!defOp) return false;
1375-
if (auto ld = dyn_cast<xetile::LoadTileOp>(defOp)) {
1376-
auto mapAttr = ld.getTile().getType().getWgMap();
1377-
if (mapAttr) {
1378-
auto newShape = mapAttr.getSgData();
1379-
resultType = VectorType::get({newShape[0], newShape[1]}, vecType.getElementType());
1380-
return true;
1381-
}
1382-
} else {
1383-
auto mapAttr = defOp->getAttrOfType<xetile::WorkGroupMapAttr>("map");
1384-
if (mapAttr) {
1385-
auto newShape = mapAttr.getSgData();
1386-
resultType = VectorType::get({newShape[0], newShape[1]}, vecType.getElementType());
1387-
return true;
1388-
}
1389-
}
1390-
}
1391-
return false;
1392-
};
1452+
converter.addConversion([&](Type type) -> Type { return type; });
1453+
converter.addConversion([&](xetile::TileType type) -> Type {
1454+
if (auto wgMap = type.getWgMap()) {
1455+
auto sgData = wgMap.getSgData();
1456+
auto newTy = xetile::TileType::get({sgData[0], sgData[1]}, type.getElementType());
1457+
return newTy;
1458+
}
1459+
return type;
1460+
});
1461+
1462+
converter.addConversion([&](RankedTensorType type) -> Type {
1463+
auto mapAttr = llvm::dyn_cast_or_null<xetile::WorkGroupMapAttr>(type.getEncoding());
1464+
1465+
if (!mapAttr)
1466+
return VectorType::get(type.getShape(), type.getElementType());
1467+
1468+
auto sgData = llvm::to_vector_of<int64_t>(mapAttr.getSgData().asArrayRef());
1469+
return VectorType::get(sgData, type.getElementType());
13931470

1394-
if (auto forOp = dyn_cast<scf::ForOp>(op)) {
1395-
for (Value iterArg : forOp.getInitArgs()) {
1396-
if (isResultType(iterArg, iterArg.getType())) {
1397-
return WalkResult::interrupt();
1398-
}
1399-
}
1400-
}
1401-
if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
1402-
for (auto yieldOp : ifOp.getThenRegion().getOps<scf::YieldOp>()) {
1403-
for (Value yieldOperand : yieldOp.getOperands()) {
1404-
if (isResultType(yieldOperand, yieldOperand.getType())) {
1405-
return WalkResult::interrupt();
1406-
}
1407-
}
1408-
}
1409-
if (!ifOp.getElseRegion().empty()) {
1410-
for (auto yieldOp : ifOp.getElseRegion().getOps<scf::YieldOp>()) {
1411-
for (Value yieldOperand : yieldOp.getOperands()) {
1412-
if (isResultType(yieldOperand, yieldOperand.getType())) {
1413-
return WalkResult::interrupt();
1414-
}
1415-
}
1416-
}
1417-
}
1418-
}
1419-
return WalkResult::advance();
1420-
});
1421-
return resultType;
14221471
});
14231472

14241473
target.addDynamicallyLegalOp<scf::ForOp, scf::IfOp, scf::YieldOp>(
@@ -1440,8 +1489,8 @@ class XeTileWgToSgPass
14401489
if (mlir::failed(
14411490
mlir::applyPartialConversion(mod, target, std::move(patterns))))
14421491
return signalPassFailure();
1443-
}
1444-
};
1492+
}
1493+
};
14451494
/// Create a pass
14461495
std::unique_ptr<Pass> createXeTileWgToSgPass() {
14471496
return std::make_unique<XeTileWgToSgPass>();

test/Dialect/XeTile/Transforms/WgToSg/unit_tests.mlir

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,5 +167,55 @@ gpu.module @test_arith_extf {
167167
%1 = xetile.init_tile %src[8, 16], [%dim0_size, %dim1_size], [%dim0_stride, %dim1_stride]
168168
: i64 -> !xetile.tile<128x64xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [4, 64]>>>
169169
gpu.return
170-
}
170+
}
171+
172+
// Two scf.while loops, each carrying a vector<128x256xf32> value that is
// initialised from a constant annotated with a `map` attribute. The expected
// output requires the loop signature and the yielded vector to use the
// sg_data shape of that map: 32x32 for %cst and 64x64 for %zero.
gpu.func @test_while_vector_op(%cond1: i1, %cond2: i1) {
173+
%cst = arith.constant {map = #xetile.wg_map<sg_layout = [4, 8], sg_data = [32, 32]>} dense<1.0> : vector<128x256xf32>
174+
%zero = arith.constant {map = #xetile.wg_map<sg_layout = [2, 4], sg_data = [64, 64]>} dense<0.0> : vector<128x256xf32>
175+
176+
// CHECK: [[RES1:%.+]] = scf.while ({{.*}}) : (vector<32x32xf32>, i1) -> vector<32x32xf32>
177+
%result1 = scf.while (%arg0 = %cst, %arg1 = %cond1) : (vector<128x256xf32>, i1) -> (vector<128x256xf32>) {
178+
%cond = arith.andi %arg1, %cond1 : i1
179+
scf.condition(%cond) %arg0 : vector<128x256xf32>
180+
} do {
181+
^bb0(%arg0: vector<128x256xf32>):
182+
// CHECK:scf.yield {{.*}} : vector<32x32xf32>, i1
183+
scf.yield %cst, %cond1 : vector<128x256xf32>, i1
184+
}
185+
186+
// CHECK: [[RES2:%.+]] = scf.while ({{.*}}) : (vector<64x64xf32>, i1) -> vector<64x64xf32>
187+
%result2 = scf.while (%arg2 = %zero, %arg3 = %cond2) : (vector<128x256xf32>, i1) -> (vector<128x256xf32>) {
188+
%cond = arith.andi %arg3, %cond2 : i1
189+
scf.condition(%cond) %arg2 : vector<128x256xf32>
190+
} do {
191+
^bb0(%arg2: vector<128x256xf32>):
192+
// CHECK:scf.yield {{.*}} : vector<64x64xf32>, i1
193+
scf.yield %zero, %cond2 : vector<128x256xf32>, i1
194+
}
195+
196+
gpu.return
197+
}
198+
// Two scf.if ops returning a vector<128x256xf32> value; both branches of each
// op yield the same constant, which is annotated with a `map` attribute. The
// expected output requires the scf.if result and both yields to use the
// sg_data shape of that map: 32x32 for %cst and 64x64 for %zero.
gpu.func @test_if_vector_op(%cond1: i1, %cond2: i1) {
199+
%cst = arith.constant {map = #xetile.wg_map<sg_layout = [4, 8], sg_data = [32, 32]>} dense<1.0> : vector<128x256xf32>
200+
%zero = arith.constant {map = #xetile.wg_map<sg_layout = [2, 4], sg_data = [64, 64]>} dense<0.0> : vector<128x256xf32>
201+
// CHECK: %[[RES1:.*]] = scf.if {{.*}} -> (vector<32x32xf32>)
202+
%result = scf.if %cond1 -> (vector<128x256xf32>) {
203+
// CHECK:scf.yield {{.*}} : vector<32x32xf32>
204+
scf.yield %cst : vector<128x256xf32>
205+
} else {
206+
// CHECK:scf.yield {{.*}} : vector<32x32xf32>
207+
scf.yield %cst : vector<128x256xf32>
208+
}
209+
210+
// CHECK: %[[RES2:.*]] = scf.if {{.*}} -> (vector<64x64xf32>)
211+
%result1 = scf.if %cond2 -> (vector<128x256xf32>) {
212+
// CHECK:scf.yield {{.*}} : vector<64x64xf32>
213+
scf.yield %zero : vector<128x256xf32>
214+
} else {
215+
// CHECK:scf.yield {{.*}} : vector<64x64xf32>
216+
scf.yield %zero : vector<128x256xf32>
217+
}
218+
gpu.return
219+
}
220+
171221
}

0 commit comments

Comments
 (0)