@@ -47,6 +47,20 @@ namespace imex {
47
47
48
48
namespace imex {
49
49
50
+ static xetile::WorkGroupMapAttr getWorkGroupMapAttr (Value val) {
51
+ auto defOp = val.getDefiningOp ();
52
+ if (!defOp)
53
+ return nullptr ;
54
+ if (auto ld = dyn_cast<xetile::LoadTileOp>(defOp)) {
55
+ return ld.getTile ().getType ().getWgMap ();
56
+ }
57
+ if (defOp->hasAttr (" map" ))
58
+ return defOp->getAttrOfType <xetile::WorkGroupMapAttr>(" map" );
59
+ if (defOp->hasAttr (" wg_map_c" ))
60
+ return defOp->getAttrOfType <xetile::WorkGroupMapAttr>(" wg_map_c" );
61
+ return nullptr ;
62
+ }
63
+
50
64
// This pass transforms the Ops at WG level to SG level using the
51
65
// decomposition attributes provided by wg_map.
52
66
// clang-format off
@@ -573,15 +587,6 @@ class WGToSGXeTileConvertLayout
573
587
return rewriter.create <arith::ConstantIndexOp>(loc, value);
574
588
};
575
589
576
- // get the workgroup map attribute for a value from its defining op.
577
- auto getWorkGroupMapAttr = [&](Value val) {
578
- auto defOp = val.getDefiningOp ();
579
- if (auto ld = dyn_cast<xetile::LoadTileOp>(defOp)) {
580
- return ld.getTile ().getType ().getWgMap ();
581
- }
582
- return defOp->getAttrOfType <xetile::WorkGroupMapAttr>(" map" );
583
- };
584
-
585
590
auto isOneUseTranspose = [&](Operation *op) {
586
591
return isa<xetile::TransposeOp, vector::TransposeOp>(op) && op->hasOneUse ();
587
592
};
@@ -1089,6 +1094,17 @@ class WGToSGMathFPowIOpPattern : public OpConversionPattern<mlir::math::FPowIOp>
1089
1094
}
1090
1095
};
1091
1096
1097
+ class UnrealizedConversionCastOpPattern : public OpConversionPattern <mlir::UnrealizedConversionCastOp> {
1098
+ using OpConversionPattern<mlir::UnrealizedConversionCastOp>::OpConversionPattern;
1099
+
1100
+ mlir::LogicalResult
1101
+ matchAndRewrite (mlir::UnrealizedConversionCastOp op, OpAdaptor adaptor,
1102
+ ConversionPatternRewriter &rewriter) const override {
1103
+ rewriter.replaceOp (op, adaptor.getOperands ());
1104
+ return mlir::success ();
1105
+ }
1106
+ };
1107
+
1092
1108
// This function traverses backwards through loop-carried dependencies in SCF
1093
1109
// `for` loops to find the original (pre-loop) value.
1094
1110
static Value getPreLoopValue (Value val) {
@@ -1185,7 +1201,7 @@ void populateXeTileWgToSgPatterns(mlir::RewritePatternSet &patterns,
1185
1201
WGToSGArithSelectOpPattern, WGToSGMathFPowIOpPattern,
1186
1202
WGToSGVectorShapeCast, WGToSGVectorMultiDimReductionOp,
1187
1203
WGToSGLoadGatherOpPattern, WGToSGStoreScatterOpPattern,
1188
- WGToSGVectorCreateMask>(patterns.getContext ());
1204
+ WGToSGVectorCreateMask, UnrealizedConversionCastOpPattern >(patterns.getContext ());
1189
1205
patterns.insert <WGToSGElementWiseOpSameArgAndResultTypePattern<math::ExpOp, 1 >,
1190
1206
WGToSGElementWiseOpSameArgAndResultTypePattern<math::SqrtOp, 1 >,
1191
1207
WGToSGElementWiseOpSameArgAndResultTypePattern<arith::AddFOp, 2 >,
@@ -1232,7 +1248,7 @@ class XeTileWgToSgPass
1232
1248
1233
1249
void runOnOperation () override {
1234
1250
MLIRContext &context = getContext ();
1235
- auto mod = this -> getOperation ();
1251
+ auto mod = getOperation ();
1236
1252
1237
1253
// skip functions with XeTile.TileType inputs and outputs
1238
1254
if (!isSupportedModule (mod)) {
@@ -1241,9 +1257,86 @@ class XeTileWgToSgPass
1241
1257
return signalPassFailure ();
1242
1258
}
1243
1259
1244
- Operation *op = getOperation ();
1245
1260
// Run the analysis to find the candidates for the transformation
1246
- analyzeTransposeOps (op, sgLayoutMap);
1261
+ analyzeTransposeOps (mod, sgLayoutMap);
1262
+
1263
+ { // temporary change the VectorType to RankedTensorType for Structure Control Flow operands
1264
+ mlir::TypeConverter converter;
1265
+ converter.addConversion ([&](Type type) -> Type { return type; });
1266
+ converter.addConversion ([&](VectorType type) -> Type {
1267
+ auto newTy = RankedTensorType::get (type.getShape (), type.getElementType ());
1268
+ return newTy;
1269
+ });
1270
+
1271
+ auto materializeCast = [&](mlir::OpBuilder &builder, mlir::Type type,
1272
+ mlir::ValueRange inputs,
1273
+ mlir::Location loc) -> mlir::Value {
1274
+ if (inputs.size () != 1 )
1275
+ return nullptr ;
1276
+ return builder.create <UnrealizedConversionCastOp>(loc, type, inputs).getResult (0 );
1277
+ };
1278
+ converter.addSourceMaterialization (materializeCast);
1279
+ converter.addTargetMaterialization (materializeCast);
1280
+
1281
+ mlir::ConversionTarget target (context);
1282
+ target.addLegalOp <UnrealizedConversionCastOp>();
1283
+
1284
+ mlir::RewritePatternSet patterns (&context);
1285
+ scf::populateSCFStructuralTypeConversionsAndLegality (converter, patterns, target);
1286
+ (void )mlir::applyPartialConversion (mod, target, std::move (patterns));
1287
+
1288
+ // propagate the layout info into the RankedTensorType result for cast ops
1289
+ mod->walk ([&](UnrealizedConversionCastOp castOp) {
1290
+ if (castOp.getNumOperands () != 1 || castOp.getNumResults () != 1 )
1291
+ return WalkResult::skip ();
1292
+
1293
+ auto input = castOp.getInputs ()[0 ];
1294
+ auto result = castOp.getResults ()[0 ];
1295
+ auto inputTy = dyn_cast<VectorType>(input.getType ());
1296
+ auto resultTy = dyn_cast<RankedTensorType>(result.getType ());
1297
+
1298
+ // Only look at ops casting from VectorType to RankedTensorType
1299
+ if (!isa<VectorType>(inputTy) || !isa<RankedTensorType>(resultTy))
1300
+ return WalkResult::skip ();
1301
+
1302
+ auto wgMap = getWorkGroupMapAttr (input);
1303
+ if (wgMap) {
1304
+ auto newTy = resultTy.cloneWithEncoding (wgMap);
1305
+ result.setType (newTy);
1306
+
1307
+ // update the arguments if user is a LoopLike op.
1308
+ for (OpOperand &use : result.getUses ()) {
1309
+ if (auto loop = dyn_cast<LoopLikeOpInterface>(use.getOwner ())) {
1310
+ auto arg = loop.getTiedLoopRegionIterArg (&use);
1311
+ arg.setType (newTy);
1312
+ }
1313
+ // whileOp has two regions, the BlockArgument of the after region
1314
+ // is not exposed by LoopLikeOpInterface
1315
+ if (auto whileOp = dyn_cast<scf::WhileOp>(use.getOwner ())) {
1316
+ auto idx = use.getOperandNumber ();
1317
+ auto arg = whileOp.getAfterArguments ()[idx];
1318
+ arg.setType (newTy);
1319
+ }
1320
+ }
1321
+ return WalkResult::advance ();
1322
+ }
1323
+ return WalkResult::skip ();
1324
+ });
1325
+
1326
+ // using yieldOp as anchor to update the result type of its ParentOp
1327
+ mod->walk ([&](scf::YieldOp yieldOp) {
1328
+ auto parentOp = yieldOp->getParentOp ();
1329
+ for (auto r: parentOp->getOpResults ()) {
1330
+ auto idx = r.getResultNumber ();
1331
+ auto resultTy = r.getType ();
1332
+ auto yieldTy = yieldOp.getResults ()[idx].getType ();
1333
+ if (isa<RankedTensorType>(resultTy) && yieldTy != resultTy)
1334
+ r.setType (yieldTy);
1335
+ }
1336
+ });
1337
+
1338
+ }
1339
+
1247
1340
mlir::ConversionTarget target (context);
1248
1341
mlir::RewritePatternSet patterns (&context);
1249
1342
@@ -1343,7 +1436,9 @@ class XeTileWgToSgPass
1343
1436
mlir::TypeConverter converter;
1344
1437
target.addIllegalOp <xetile::ConvertLayoutOp>();
1345
1438
1346
- target.markUnknownOpDynamicallyLegal ([](Operation *) { return true ; });
1439
+ target.markUnknownOpDynamicallyLegal ([](Operation *op) {
1440
+ return !isa<mlir::UnrealizedConversionCastOp>(op);
1441
+ });
1347
1442
1348
1443
populateXeTileWgToSgPatterns (patterns, sgLayoutMap);
1349
1444
@@ -1354,71 +1449,25 @@ class XeTileWgToSgPass
1354
1449
// handle the conversion of the vector/tile type of same shape
1355
1450
// mapped to different sgData for region ops.
1356
1451
// TODO : Fix the type converter to handle such case.
1357
- converter.addConversion ([op](Type type) -> Type {
1358
- Type resultType = type;
1359
- auto vecType = dyn_cast<VectorType>(type);
1360
- auto tileTy = dyn_cast<xetile::TileType>(type);
1361
- if (!vecType && !tileTy) return resultType;
1362
-
1363
- op->walk ([&](Operation *op) {
1364
- auto isResultType = [&](Value value, Type valueType) {
1365
- if (valueType != type) return false ;
1366
- if (tileTy) {
1367
- if (!tileTy.getWgMap ()) return false ;
1368
- auto newShape = tileTy.getWgMap ().getSgData ();
1369
- resultType = xetile::TileType::get ({newShape[0 ], newShape[1 ]}, tileTy.getElementType ());
1370
- return true ;
1371
- }
1372
- if (vecType) {
1373
- Operation *defOp = value.getDefiningOp ();
1374
- if (!defOp) return false ;
1375
- if (auto ld = dyn_cast<xetile::LoadTileOp>(defOp)) {
1376
- auto mapAttr = ld.getTile ().getType ().getWgMap ();
1377
- if (mapAttr) {
1378
- auto newShape = mapAttr.getSgData ();
1379
- resultType = VectorType::get ({newShape[0 ], newShape[1 ]}, vecType.getElementType ());
1380
- return true ;
1381
- }
1382
- } else {
1383
- auto mapAttr = defOp->getAttrOfType <xetile::WorkGroupMapAttr>(" map" );
1384
- if (mapAttr) {
1385
- auto newShape = mapAttr.getSgData ();
1386
- resultType = VectorType::get ({newShape[0 ], newShape[1 ]}, vecType.getElementType ());
1387
- return true ;
1388
- }
1389
- }
1390
- }
1391
- return false ;
1392
- };
1452
+ converter.addConversion ([&](Type type) -> Type { return type; });
1453
+ converter.addConversion ([&](xetile::TileType type) -> Type {
1454
+ if (auto wgMap = type.getWgMap ()) {
1455
+ auto sgData = wgMap.getSgData ();
1456
+ auto newTy = xetile::TileType::get ({sgData[0 ], sgData[1 ]}, type.getElementType ());
1457
+ return newTy;
1458
+ }
1459
+ return type;
1460
+ });
1461
+
1462
+ converter.addConversion ([&](RankedTensorType type) -> Type {
1463
+ auto mapAttr = llvm::dyn_cast_or_null<xetile::WorkGroupMapAttr>(type.getEncoding ());
1464
+
1465
+ if (!mapAttr)
1466
+ return VectorType::get (type.getShape (), type.getElementType ());
1467
+
1468
+ auto sgData = llvm::to_vector_of<int64_t >(mapAttr.getSgData ().asArrayRef ());
1469
+ return VectorType::get (sgData, type.getElementType ());
1393
1470
1394
- if (auto forOp = dyn_cast<scf::ForOp>(op)) {
1395
- for (Value iterArg : forOp.getInitArgs ()) {
1396
- if (isResultType (iterArg, iterArg.getType ())) {
1397
- return WalkResult::interrupt ();
1398
- }
1399
- }
1400
- }
1401
- if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
1402
- for (auto yieldOp : ifOp.getThenRegion ().getOps <scf::YieldOp>()) {
1403
- for (Value yieldOperand : yieldOp.getOperands ()) {
1404
- if (isResultType (yieldOperand, yieldOperand.getType ())) {
1405
- return WalkResult::interrupt ();
1406
- }
1407
- }
1408
- }
1409
- if (!ifOp.getElseRegion ().empty ()) {
1410
- for (auto yieldOp : ifOp.getElseRegion ().getOps <scf::YieldOp>()) {
1411
- for (Value yieldOperand : yieldOp.getOperands ()) {
1412
- if (isResultType (yieldOperand, yieldOperand.getType ())) {
1413
- return WalkResult::interrupt ();
1414
- }
1415
- }
1416
- }
1417
- }
1418
- }
1419
- return WalkResult::advance ();
1420
- });
1421
- return resultType;
1422
1471
});
1423
1472
1424
1473
target.addDynamicallyLegalOp <scf::ForOp, scf::IfOp, scf::YieldOp>(
@@ -1440,8 +1489,8 @@ class XeTileWgToSgPass
1440
1489
if (mlir::failed (
1441
1490
mlir::applyPartialConversion (mod, target, std::move (patterns))))
1442
1491
return signalPassFailure ();
1443
- }
1444
- };
1492
+ }
1493
+ };
1445
1494
// / Create a pass
1446
1495
std::unique_ptr<Pass> createXeTileWgToSgPass () {
1447
1496
return std::make_unique<XeTileWgToSgPass>();
0 commit comments