3
3
#include " mlir/IR/BuiltinAttributes.h"
4
4
#include " mlir/IR/Dominance.h"
5
5
#include " mlir/IR/IRMapping.h"
6
+ #include " mlir/IR/Matchers.h"
6
7
#include " mlir/IR/PatternMatch.h"
7
8
#include " mlir/IR/Verifier.h"
9
+ #include " mlir/Interfaces/InferTypeOpInterface.h"
8
10
#include " mlir/Interfaces/SideEffectInterfaces.h"
11
+ #include " mlir/Pass/Pass.h"
9
12
#include " mlir/Pass/PassManager.h"
10
13
#include " mlir/Support/LogicalResult.h"
11
14
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
15
+ #include " mlir/Transforms/Passes.h"
16
+ #include " mlir/Transforms/RegionUtils.h"
12
17
13
18
#include " intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
14
19
#include " intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"
17
22
18
23
#include " triton/Analysis/Utility.h"
19
24
#include " triton/Dialect/Triton/IR/Dialect.h"
20
- #include " triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
25
+ #include " triton/Dialect/TritonGPU/IR/Dialect.h"
26
+ #include " triton/Dialect/TritonGPU/Transforms/Passes.h"
27
+ #include " triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
21
28
#include " triton/Dialect/TritonGPU/Transforms/Utility.h"
22
29
#include < deque>
23
30
@@ -106,6 +113,8 @@ class LayoutPropagation {
106
113
// Return the mapped value in the given encoding. This will insert a convert
107
114
// if the encoding is different than the encoding decided at resolve time.
108
115
Value getValueAs (Value value, Attribute encoding);
116
+ // Return the original value mapped to the new desired encoding.
117
+ Value getRewrittenValue (Value value);
109
118
// Dump the current stage of layout information.
110
119
void dump ();
111
120
@@ -190,7 +199,7 @@ bool isLayoutAnchor(Operation *op) {
190
199
return ttgi::isExpensiveLoadOrStore (op);
191
200
// TODO: we should estimate the cost of the not propagating layout for
192
201
// AtomicCAS for further performance consideration.
193
- if (isa<DotOp, AtomicCASOp>(op))
202
+ if (isa<DotOp, DotScaledOp, AtomicCASOp>(op))
194
203
return true ;
195
204
if (isa<AtomicRMWOp>(op))
196
205
if (auto tensorType =
@@ -304,6 +313,15 @@ SmallVector<Value> LayoutPropagation::propagateToUsers(Value value,
304
313
setEncoding (user->getResults (), info, changed, user);
305
314
continue ;
306
315
}
316
+ if (auto gatherOp = dyn_cast<GatherOp>(user)) {
317
+ // Propagate the layout through the indices only, and if the layout does
318
+ // not have an efficient layout set.
319
+ if (!gatherOp.getEfficientLayout () &&
320
+ &use == &gatherOp.getIndicesMutable ()) {
321
+ setEncoding (gatherOp.getResult (), info, changed, user);
322
+ continue ;
323
+ }
324
+ }
307
325
if (auto storeOp = dyn_cast<StoreOp>(user)) {
308
326
auto checkMMAorMMADerived = [](Attribute encoding) {
309
327
bool isMMAorMMADerived = isa<MmaEncodingTrait>(encoding);
@@ -339,7 +357,7 @@ SmallVector<Value> LayoutPropagation::propagateToUsers(Value value,
339
357
340
358
void LayoutPropagation::propagateLayout () {
341
359
SmallVector<Value> queue;
342
- for (const auto & it : layouts) {
360
+ for (auto it : layouts) {
343
361
queue.push_back (it.first );
344
362
}
345
363
while (!queue.empty ()) {
@@ -353,6 +371,7 @@ void LayoutPropagation::propagateLayout() {
353
371
<< info.encodings .size () << " candidate encoding(s):\n " ;
354
372
for (Attribute encoding : info.encodings )
355
373
DBGS () << " " << encoding << " \n " ;
374
+ DBGS () << " changed: " << changed.size () << " \n " ;
356
375
});
357
376
358
377
queue.insert (queue.end (), changed.begin (), changed.end ());
@@ -469,22 +488,25 @@ void LayoutPropagation::map(Value old, Value newV) {
469
488
newV;
470
489
}
471
490
491
+ Value LayoutPropagation::getRewrittenValue (Value value) {
492
+ auto tensorType = dyn_cast<RankedTensorType>(value.getType ());
493
+ if (!tensorType)
494
+ return value;
495
+ auto layoutIt = layouts.find (value);
496
+ if (layoutIt == layouts.end ()) {
497
+ return value;
498
+ }
499
+ assert (layoutIt->second .encodings .size () == 1 &&
500
+ " we should have resolved to a single encoding" );
501
+ Attribute encodingPicked = *(layoutIt->second .encodings .begin ());
502
+ if (encodingPicked == tensorType.getEncoding ())
503
+ return value;
504
+ return rewriteMapping.at ({value, encodingPicked});
505
+ }
506
+
472
507
Value LayoutPropagation::getValueAs (Value value, Attribute encoding) {
473
508
if (auto tensorType = dyn_cast<RankedTensorType>(value.getType ())) {
474
- Value rewrittenValue;
475
- auto layoutIt = layouts.find (value);
476
- if (layoutIt == layouts.end ()) {
477
- rewrittenValue = value;
478
- } else {
479
- assert (layoutIt->second .encodings .size () == 1 &&
480
- " we should have resolved to a single encoding" );
481
- Attribute encodingPicked = *(layoutIt->second .encodings .begin ());
482
- if (encodingPicked == tensorType.getEncoding ())
483
- rewrittenValue = value;
484
- else
485
- rewrittenValue = rewriteMapping[{value, encodingPicked}];
486
- }
487
- assert (rewrittenValue);
509
+ Value rewrittenValue = getRewrittenValue (value);
488
510
if (cast<RankedTensorType>(rewrittenValue.getType ()).getEncoding () ==
489
511
encoding)
490
512
return rewrittenValue;
@@ -922,7 +944,7 @@ Operation *LayoutPropagation::rewriteOp(Operation *op) {
922
944
}
923
945
if (op->hasTrait <OpTrait::SameOperandsAndResultEncoding>() ||
924
946
op->hasTrait <OpTrait::Elementwise>() ||
925
- isa<ReduceOp, ExpandDimsOp, ReshapeOp, TransOp, JoinOp, SplitOp,
947
+ isa<ReduceOp, ExpandDimsOp, ReshapeOp, TransOp, JoinOp, SplitOp, GatherOp,
926
948
ConvertLayoutOp>(op)) {
927
949
Operation *newOp = cloneElementwise (rewriter, op, encoding);
928
950
for (auto [oldResult, newResult] :
@@ -944,6 +966,9 @@ bool canBeRemat(Operation *op) {
944
966
return !ttgi::isExpensiveLoadOrStore (op);
945
967
if (isa<AtomicRMWOp, AtomicCASOp, DotOp>(op))
946
968
return false ;
969
+ if (auto gather = dyn_cast<GatherOp>(op))
970
+ return !gather.getEfficientLayout ();
971
+
947
972
if (isa<scf::WhileOp, scf::ConditionOp>(op))
948
973
return false ;
949
974
@@ -1211,8 +1236,8 @@ LogicalResult LayoutRematerialization::getRematerializableSlice(
1211
1236
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
1212
1237
DenseMap<Value, Attribute> &layout,
1213
1238
std::function<bool (Operation *)> stopPropagation) {
1214
- LogicalResult result = getConvertBackwardSlice (
1215
- root, rootEncoding, slice, layout, std::move ( stopPropagation) );
1239
+ LogicalResult result = getConvertBackwardSlice (root, rootEncoding, slice,
1240
+ layout, stopPropagation);
1216
1241
if (result.failed () || slice.empty ())
1217
1242
return failure ();
1218
1243
@@ -1226,13 +1251,13 @@ LogicalResult LayoutRematerialization::getRematerializableSlice(
1226
1251
return success ();
1227
1252
}
1228
1253
1229
- void LayoutRematerialization::hoistConvertIntoConditionals () {
1254
+ void LayoutRematerialization::backwardRematerialization () {
1230
1255
// Go through each ConvertLayoutOp.
1231
1256
SmallVector<ConvertLayoutOp> convertOps;
1232
1257
funcOp.walk (
1233
1258
[&](ConvertLayoutOp convertOp) { convertOps.push_back (convertOp); });
1234
1259
for (ConvertLayoutOp convertOp : convertOps) {
1235
- hoistConvertIntoConditionals (convertOp);
1260
+ backwardRematerialization (convertOp);
1236
1261
if (!opToDelete.contains (convertOp)) {
1237
1262
// If the conversion didn't get removed, consider it for reuse in future
1238
1263
// backward slices.
@@ -1242,13 +1267,13 @@ void LayoutRematerialization::hoistConvertIntoConditionals() {
1242
1267
}
1243
1268
}
1244
1269
1245
- void LayoutRematerialization::backwardRematerialization () {
1270
+ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast () {
1246
1271
// Go through each ConvertLayoutOp.
1247
1272
SmallVector<ConvertLayoutOp> convertOps;
1248
1273
funcOp.walk (
1249
1274
[&](ConvertLayoutOp convertOp) { convertOps.push_back (convertOp); });
1250
1275
for (ConvertLayoutOp convertOp : convertOps) {
1251
- backwardRematerialization (convertOp);
1276
+ hoistConvertOnTopOfExtOrBroadcast (convertOp);
1252
1277
if (!opToDelete.contains (convertOp)) {
1253
1278
// If the conversion didn't get removed, consider it for reuse in future
1254
1279
// backward slices.
@@ -1258,13 +1283,13 @@ void LayoutRematerialization::backwardRematerialization() {
1258
1283
}
1259
1284
}
1260
1285
1261
- void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast () {
1286
+ void LayoutRematerialization::hoistConvertIntoConditionals () {
1262
1287
// Go through each ConvertLayoutOp.
1263
1288
SmallVector<ConvertLayoutOp> convertOps;
1264
1289
funcOp.walk (
1265
1290
[&](ConvertLayoutOp convertOp) { convertOps.push_back (convertOp); });
1266
1291
for (ConvertLayoutOp convertOp : convertOps) {
1267
- hoistConvertOnTopOfExtOrBroadcast (convertOp);
1292
+ hoistConvertIntoConditionals (convertOp);
1268
1293
if (!opToDelete.contains (convertOp)) {
1269
1294
// If the conversion didn't get removed, consider it for reuse in future
1270
1295
// backward slices.
@@ -1274,6 +1299,40 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {
1274
1299
}
1275
1300
}
1276
1301
1302
+ static bool isExpensiveMathOp (Operation *op) {
1303
+ // These operations are either multiple instructions or have throughput
1304
+ // lower than 16 according to the arithmetic instructions table in:
1305
+ // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#arithmetic-instructions
1306
+ return isa<arith::DivFOp, math::ErfcOp, math::SinhOp, math::CoshOp,
1307
+ math::TanhOp, math::AsinhOp, math::AcoshOp, math::AtanhOp,
1308
+ math::CtPopOp, math::CountLeadingZerosOp,
1309
+ math::CountTrailingZerosOp, math::ExpOp, math::Exp2Op,
1310
+ math::ExpM1Op, math::LogOp, math::Log2Op, math::Log10Op,
1311
+ math::Log1pOp, math::SinOp, math::CosOp, math::TanOp, math::AsinOp,
1312
+ math::AcosOp, math::AtanOp, math::Atan2Op, math::PowFOp,
1313
+ math::SqrtOp, math::RsqrtOp, math::ErfOp, math::CbrtOp>(op);
1314
+ }
1315
+
1316
+ static int64_t getByteCount (Value result, int64_t minElementCount = 0 ,
1317
+ int64_t minBitWidth = 0 ) {
1318
+ int64_t elementCount = 0 ;
1319
+ int64_t dtypeBitWidth = 0 ;
1320
+ if (auto tensorTy = dyn_cast<RankedTensorType>(result.getType ())) {
1321
+ elementCount = tensorTy.getNumElements ();
1322
+ auto elemType = tensorTy.getElementType ();
1323
+ if (elemType.isIntOrFloat ()) {
1324
+ dtypeBitWidth = elemType.getIntOrFloatBitWidth ();
1325
+ }
1326
+ }
1327
+ if (elementCount < minElementCount) {
1328
+ elementCount = minElementCount;
1329
+ }
1330
+ if (dtypeBitWidth < minBitWidth) {
1331
+ dtypeBitWidth = minBitWidth;
1332
+ }
1333
+ return (elementCount * dtypeBitWidth) >> 3 ;
1334
+ }
1335
+
1277
1336
void LayoutRematerialization::backwardRematerialization (
1278
1337
ConvertLayoutOp convertOp) {
1279
1338
RankedTensorType targetType = convertOp.getType ();
@@ -1373,30 +1432,32 @@ void LayoutRematerialization::hoistConvertDotOperand(
1373
1432
{ DBGS () << " Block arguments not supported. Got " << v << " \n " ; });
1374
1433
return ;
1375
1434
}
1376
- auto loadOp = dyn_cast<LoadOp>(v. getDefiningOp ());
1377
- // We expect the leaves of the slice to be Load or arith::Constant
1378
- // This could be generalised if necessary
1379
- if (!loadOp ) {
1435
+
1436
+ // We expect the leaves of the slice to be Load, DescriptorLoad or
1437
+ // arith::Constant This could be generalised if necessary
1438
+ if (!isa<LoadOp, DescriptorLoadOp>(v. getDefiningOp ()) ) {
1380
1439
auto op = v.getDefiningOp ();
1381
1440
if (isa<arith::ConstantOp>(op) || noDataMovement (op)) {
1382
1441
innerSlice.insert (v);
1383
1442
continue ;
1384
1443
} else {
1385
1444
LLVM_DEBUG ({
1386
- DBGS () << " Leaves must be Load or Constant. Got " << v << " \n " ;
1445
+ DBGS () << " Leaves must be Load, DescriptorLoad or Constant. Got "
1446
+ << v << " \n " ;
1387
1447
});
1388
1448
return ;
1389
1449
}
1390
1450
}
1451
+ Operation *loadOp = v.getDefiningOp ();
1391
1452
builder.setInsertionPointAfter (loadOp);
1392
- auto type = dyn_cast<RankedTensorType>(loadOp.getType ());
1453
+ auto type = dyn_cast<RankedTensorType>(loadOp-> getResult ( 0 ) .getType ());
1393
1454
if (!type)
1394
1455
continue ;
1395
1456
auto newType = RankedTensorType::get (type.getShape (), type.getElementType (),
1396
- layout[loadOp]);
1457
+ layout[loadOp-> getResult ( 0 ) ]);
1397
1458
auto newConvertOp = builder.create <ConvertLayoutOp>(
1398
- convertOp.getLoc (), newType, loadOp. getResult ());
1399
- mapping.map (loadOp. getResult (), newConvertOp.getResult ());
1459
+ convertOp.getLoc (), newType, loadOp-> getResult (0 ));
1460
+ mapping.map (loadOp-> getResult (0 ), newConvertOp.getResult ());
1400
1461
}
1401
1462
1402
1463
if (innerSlice.empty ()) {
@@ -1418,7 +1479,7 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
1418
1479
ConvertLayoutOp convertOp) {
1419
1480
// DotOperand is hoisted by hoistDotOperand
1420
1481
RankedTensorType targetType = convertOp.getType ();
1421
- if (mlir:: isa<DotOperandEncodingAttr>(targetType.getEncoding ()))
1482
+ if (isa<DotOperandEncodingAttr>(targetType.getEncoding ()))
1422
1483
return ;
1423
1484
1424
1485
auto isExtOrBroadcastOp = [](Operation *op) {
@@ -1641,6 +1702,7 @@ void hoistConvert(ModuleOp module) {
1641
1702
layoutRemat.cleanup ();
1642
1703
});
1643
1704
}
1705
+ } // namespace
1644
1706
1645
1707
class TritonIntelGPURemoveLayoutConversionsPass
1646
1708
: public triton::gpu::intel::impl::
@@ -1722,5 +1784,3 @@ class TritonIntelGPURemoveLayoutConversionsPass
1722
1784
});
1723
1785
}
1724
1786
};
1725
-
1726
- } // namespace
0 commit comments