 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Verifier.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/RegionUtils.h"
 
 #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
 #include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"
 
 #include "triton/Analysis/Utility.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
-#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
+#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 #include <deque>
 
@@ -106,6 +113,8 @@ class LayoutPropagation {
   // Return the mapped value in the given encoding. This will insert a convert
   // if the encoding is different than the encoding decided at resolve time.
   Value getValueAs(Value value, Attribute encoding);
+  // Return the original value mapped to the new desired encoding.
+  Value getRewrittenValue(Value value);
   // Dump the current stage of layout information.
   void dump();
 
@@ -190,7 +199,7 @@ bool isLayoutAnchor(Operation *op) {
     return ttgi::isExpensiveLoadOrStore(op);
   // TODO: we should estimate the cost of the not propagating layout for
   // AtomicCAS for further performance consideration.
-  if (isa<DotOp, AtomicCASOp>(op))
+  if (isa<DotOp, DotScaledOp, AtomicCASOp>(op))
     return true;
   if (isa<AtomicRMWOp>(op))
     if (auto tensorType =
@@ -304,6 +313,15 @@ SmallVector<Value> LayoutPropagation::propagateToUsers(Value value,
       setEncoding(user->getResults(), info, changed, user);
       continue;
     }
+    if (auto gatherOp = dyn_cast<GatherOp>(user)) {
+      // Propagate the layout through the indices only, and if the layout does
+      // not have an efficient layout set.
+      if (!gatherOp.getEfficientLayout() &&
+          &use == &gatherOp.getIndicesMutable()) {
+        setEncoding(gatherOp.getResult(), info, changed, user);
+        continue;
+      }
+    }
     if (auto storeOp = dyn_cast<StoreOp>(user)) {
       auto checkMMAorMMADerived = [](Attribute encoding) {
         bool isMMAorMMADerived = isa<MmaEncodingTrait>(encoding);
@@ -339,7 +357,7 @@ SmallVector<Value> LayoutPropagation::propagateToUsers(Value value,
 
 void LayoutPropagation::propagateLayout() {
   SmallVector<Value> queue;
-  for (const auto &it : layouts) {
+  for (auto it : layouts) {
     queue.push_back(it.first);
   }
   while (!queue.empty()) {
@@ -353,6 +371,7 @@ void LayoutPropagation::propagateLayout() {
              << info.encodings.size() << " candidate encoding(s):\n";
       for (Attribute encoding : info.encodings)
         DBGS() << "  " << encoding << "\n";
+      DBGS() << "changed: " << changed.size() << "\n";
     });
 
     queue.insert(queue.end(), changed.begin(), changed.end());
@@ -469,22 +488,25 @@ void LayoutPropagation::map(Value old, Value newV) {
       newV;
 }
 
+Value LayoutPropagation::getRewrittenValue(Value value) {
+  auto tensorType = dyn_cast<RankedTensorType>(value.getType());
+  if (!tensorType)
+    return value;
+  auto layoutIt = layouts.find(value);
+  if (layoutIt == layouts.end()) {
+    return value;
+  }
+  assert(layoutIt->second.encodings.size() == 1 &&
+         "we should have resolved to a single encoding");
+  Attribute encodingPicked = *(layoutIt->second.encodings.begin());
+  if (encodingPicked == tensorType.getEncoding())
+    return value;
+  return rewriteMapping.at({value, encodingPicked});
+}
+
 Value LayoutPropagation::getValueAs(Value value, Attribute encoding) {
   if (auto tensorType = dyn_cast<RankedTensorType>(value.getType())) {
-    Value rewrittenValue;
-    auto layoutIt = layouts.find(value);
-    if (layoutIt == layouts.end()) {
-      rewrittenValue = value;
-    } else {
-      assert(layoutIt->second.encodings.size() == 1 &&
-             "we should have resolved to a single encoding");
-      Attribute encodingPicked = *(layoutIt->second.encodings.begin());
-      if (encodingPicked == tensorType.getEncoding())
-        rewrittenValue = value;
-      else
-        rewrittenValue = rewriteMapping[{value, encodingPicked}];
-    }
-    assert(rewrittenValue);
+    Value rewrittenValue = getRewrittenValue(value);
     if (cast<RankedTensorType>(rewrittenValue.getType()).getEncoding() ==
         encoding)
       return rewrittenValue;
@@ -922,7 +944,7 @@ Operation *LayoutPropagation::rewriteOp(Operation *op) {
   }
   if (op->hasTrait<OpTrait::SameOperandsAndResultEncoding>() ||
       op->hasTrait<OpTrait::Elementwise>() ||
-      isa<ReduceOp, ExpandDimsOp, ReshapeOp, TransOp, JoinOp, SplitOp,
+      isa<ReduceOp, ExpandDimsOp, ReshapeOp, TransOp, JoinOp, SplitOp, GatherOp,
           ConvertLayoutOp>(op)) {
     Operation *newOp = cloneElementwise(rewriter, op, encoding);
     for (auto [oldResult, newResult] :
@@ -944,6 +966,9 @@ bool canBeRemat(Operation *op) {
     return !ttgi::isExpensiveLoadOrStore(op);
   if (isa<AtomicRMWOp, AtomicCASOp, DotOp>(op))
     return false;
+  if (auto gather = dyn_cast<GatherOp>(op))
+    return !gather.getEfficientLayout();
+
   if (isa<scf::WhileOp, scf::ConditionOp>(op))
     return false;
 
@@ -1211,8 +1236,8 @@ LogicalResult LayoutRematerialization::getRematerializableSlice(
     OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
     DenseMap<Value, Attribute> &layout,
     std::function<bool(Operation *)> stopPropagation) {
-  LogicalResult result = getConvertBackwardSlice(
-      root, rootEncoding, slice, layout, std::move(stopPropagation));
+  LogicalResult result = getConvertBackwardSlice(root, rootEncoding, slice,
+                                                 layout, stopPropagation);
   if (result.failed() || slice.empty())
     return failure();
 
@@ -1226,13 +1251,13 @@ LogicalResult LayoutRematerialization::getRematerializableSlice(
   return success();
 }
 
-void LayoutRematerialization::hoistConvertIntoConditionals() {
+void LayoutRematerialization::backwardRematerialization() {
   // Go through each ConvertLayoutOp.
   SmallVector<ConvertLayoutOp> convertOps;
   funcOp.walk(
       [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); });
   for (ConvertLayoutOp convertOp : convertOps) {
-    hoistConvertIntoConditionals(convertOp);
+    backwardRematerialization(convertOp);
     if (!opToDelete.contains(convertOp)) {
       // If the conversion didn't get removed, consider it for reuse in future
       // backward slices.
@@ -1242,13 +1267,13 @@ void LayoutRematerialization::hoistConvertIntoConditionals() {
   }
 }
 
-void LayoutRematerialization::backwardRematerialization() {
+void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {
   // Go through each ConvertLayoutOp.
   SmallVector<ConvertLayoutOp> convertOps;
   funcOp.walk(
       [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); });
   for (ConvertLayoutOp convertOp : convertOps) {
-    backwardRematerialization(convertOp);
+    hoistConvertOnTopOfExtOrBroadcast(convertOp);
     if (!opToDelete.contains(convertOp)) {
       // If the conversion didn't get removed, consider it for reuse in future
       // backward slices.
@@ -1258,13 +1283,13 @@ void LayoutRematerialization::backwardRematerialization() {
 }
 
 
-void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {
+void LayoutRematerialization::hoistConvertIntoConditionals() {
   // Go through each ConvertLayoutOp.
   SmallVector<ConvertLayoutOp> convertOps;
   funcOp.walk(
       [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); });
   for (ConvertLayoutOp convertOp : convertOps) {
-    hoistConvertOnTopOfExtOrBroadcast(convertOp);
+    hoistConvertIntoConditionals(convertOp);
     if (!opToDelete.contains(convertOp)) {
       // If the conversion didn't get removed, consider it for reuse in future
       // backward slices.
@@ -1274,6 +1299,40 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {
   }
 }
 
+static bool isExpensiveMathOp(Operation *op) {
+  // These operations are either multiple instructions or have throughput
+  // lower than 16 according to the arithmetic instructions table in:
+  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#arithmetic-instructions
+  return isa<arith::DivFOp, math::ErfcOp, math::SinhOp, math::CoshOp,
+             math::TanhOp, math::AsinhOp, math::AcoshOp, math::AtanhOp,
+             math::CtPopOp, math::CountLeadingZerosOp,
+             math::CountTrailingZerosOp, math::ExpOp, math::Exp2Op,
+             math::ExpM1Op, math::LogOp, math::Log2Op, math::Log10Op,
+             math::Log1pOp, math::SinOp, math::CosOp, math::TanOp, math::AsinOp,
+             math::AcosOp, math::AtanOp, math::Atan2Op, math::PowFOp,
+             math::SqrtOp, math::RsqrtOp, math::ErfOp, math::CbrtOp>(op);
+}
+
+static int64_t getByteCount(Value result, int64_t minElementCount = 0,
+                            int64_t minBitWidth = 0) {
+  int64_t elementCount = 0;
+  int64_t dtypeBitWidth = 0;
+  if (auto tensorTy = dyn_cast<RankedTensorType>(result.getType())) {
+    elementCount = tensorTy.getNumElements();
+    auto elemType = tensorTy.getElementType();
+    if (elemType.isIntOrFloat()) {
+      dtypeBitWidth = elemType.getIntOrFloatBitWidth();
+    }
+  }
+  if (elementCount < minElementCount) {
+    elementCount = minElementCount;
+  }
+  if (dtypeBitWidth < minBitWidth) {
+    dtypeBitWidth = minBitWidth;
+  }
+  return (elementCount * dtypeBitWidth) >> 3;
+}
+
 void LayoutRematerialization::backwardRematerialization(
     ConvertLayoutOp convertOp) {
   RankedTensorType targetType = convertOp.getType();
@@ -1373,30 +1432,32 @@ void LayoutRematerialization::hoistConvertDotOperand(
           { DBGS() << "Block arguments not supported. Got " << v << "\n"; });
       return;
     }
-    auto loadOp = dyn_cast<LoadOp>(v.getDefiningOp());
-    // We expect the leaves of the slice to be Load or arith::Constant
-    // This could be generalised if necessary
-    if (!loadOp) {
+
+    // We expect the leaves of the slice to be Load, DescriptorLoad or
+    // arith::Constant This could be generalised if necessary
+    if (!isa<LoadOp, DescriptorLoadOp>(v.getDefiningOp())) {
       auto op = v.getDefiningOp();
       if (isa<arith::ConstantOp>(op) || noDataMovement(op)) {
         innerSlice.insert(v);
         continue;
       } else {
         LLVM_DEBUG({
-          DBGS() << "Leaves must be Load or Constant. Got " << v << "\n";
+          DBGS() << "Leaves must be Load, DescriptorLoad or Constant. Got "
+                 << v << "\n";
         });
         return;
       }
     }
+    Operation *loadOp = v.getDefiningOp();
     builder.setInsertionPointAfter(loadOp);
-    auto type = dyn_cast<RankedTensorType>(loadOp.getType());
+    auto type = dyn_cast<RankedTensorType>(loadOp->getResult(0).getType());
     if (!type)
       continue;
     auto newType = RankedTensorType::get(type.getShape(), type.getElementType(),
-                                         layout[loadOp]);
+                                         layout[loadOp->getResult(0)]);
     auto newConvertOp = builder.create<ConvertLayoutOp>(
-        convertOp.getLoc(), newType, loadOp.getResult());
-    mapping.map(loadOp.getResult(), newConvertOp.getResult());
+        convertOp.getLoc(), newType, loadOp->getResult(0));
+    mapping.map(loadOp->getResult(0), newConvertOp.getResult());
   }
 
   if (innerSlice.empty()) {
@@ -1418,7 +1479,7 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
     ConvertLayoutOp convertOp) {
   // DotOperand is hoisted by hoistDotOperand
   RankedTensorType targetType = convertOp.getType();
-  if (mlir::isa<DotOperandEncodingAttr>(targetType.getEncoding()))
+  if (isa<DotOperandEncodingAttr>(targetType.getEncoding()))
     return;
 
   auto isExtOrBroadcastOp = [](Operation *op) {
@@ -1641,6 +1702,7 @@ void hoistConvert(ModuleOp module) {
     layoutRemat.cleanup();
   });
 }
+} // namespace
 
 class TritonIntelGPURemoveLayoutConversionsPass
     : public triton::gpu::intel::impl::
@@ -1722,5 +1784,3 @@ class TritonIntelGPURemoveLayoutConversionsPass
     });
   }
 };
-
-} // namespace