@@ -1057,6 +1057,40 @@ void LayoutRematerialization::hoistConvertIntoConditionals() {
   }
 }
 
+static bool isExpensiveMathOp(Operation *op) {
+  // These operations are either multiple instructions or have throughput
+  // lower than 16 according to the arithmetic instructions table in:
+  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#arithmetic-instructions
+  return isa<arith::DivFOp, math::ErfcOp, math::SinhOp, math::CoshOp,
+             math::TanhOp, math::AsinhOp, math::AcoshOp, math::AtanhOp,
+             math::CtPopOp, math::CountLeadingZerosOp,
+             math::CountTrailingZerosOp, math::ExpOp, math::Exp2Op,
+             math::ExpM1Op, math::LogOp, math::Log2Op, math::Log10Op,
+             math::Log1pOp, math::SinOp, math::CosOp, math::TanOp, math::AsinOp,
+             math::AcosOp, math::AtanOp, math::Atan2Op, math::PowFOp,
+             math::SqrtOp, math::RsqrtOp, math::ErfOp, math::CbrtOp>(op);
+}
+
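+// Returns the size in bytes of `result` if it has a ranked tensor type of
+// int/float elements (and zero otherwise), clamping the element count and
+// element bitwidth from below so callers can model worst-case traffic.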
+static int64_t getByteCount(Value result, int64_t minElementCount = 0,
+                            int64_t minBitWidth = 0) {
+  int64_t elementCount = 0;
+  int64_t dtypeBitWidth = 0;
+  if (auto tensorTy = dyn_cast<RankedTensorType>(result.getType())) {
+    elementCount = tensorTy.getNumElements();
+    auto elemType = tensorTy.getElementType();
+    if (elemType.isIntOrFloat()) {
+      dtypeBitWidth = elemType.getIntOrFloatBitWidth();
+    }
+  }
+  if (elementCount < minElementCount) {
+    elementCount = minElementCount;
+  }
+  if (dtypeBitWidth < minBitWidth) {
+    dtypeBitWidth = minBitWidth;
+  }
+  return (elementCount * dtypeBitWidth) >> 3;
+}
+
 void LayoutRematerialization::backwardRematerialization(
     ConvertLayoutOp convertOp) {
   // DotOperand is hoisted by hoistDotOperand
@@ -1088,12 +1122,112 @@ void LayoutRematerialization::backwardRematerialization(
     return;
   }
 
+  // 2. Determine whether rematerialisation is beneficial.
+
+  // Identify all operations in the slice
+  SetVector<Operation *> sliceOps;
+  for (Value v : slice) {
+    if (Operation *op = v.getDefiningOp()) {
+      sliceOps.insert(op);
+    }
+  }
+
+  // Compute single-use operations
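+  // An operation counts as single-use when each of its results is consumed
+  // only by convertOp or, transitively, by other single-use operations in
+  // the slice; such operations are not duplicated by rematerialisation.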
+  DenseMap<Operation *, bool> isSingleUse;
+  std::function<bool(Operation *)> isOpSingleUse;
+  isOpSingleUse = [&](Operation *op) -> bool {
+    // lookup in memoization array:
+    auto it = isSingleUse.find(op);
+    if (it != isSingleUse.end()) {
+      return it->second;
+    }
+
+    bool singleUse = true;
+
+    for (Value result : op->getResults()) {
+      for (Operation *user : result.getUsers()) {
+        if (user == convertOp) {
+          continue;
+        }
+        if (sliceOps.contains(user)) {
+          if (!isOpSingleUse(user)) {
+            singleUse = false;
+            break;
+          }
+        } else {
+          singleUse = false;
+          break;
+        }
+      }
+      if (!singleUse) {
+        break;
+      }
+    }
+
+    // insert into memoization array:
+    isSingleUse[op] = singleUse;
+    return singleUse;
+  };
+
+  // Measure the number of bytes that we're manipulating with the
+  // ConvertLayoutOp. We pessimistically assume that we round-trip
+  // through shared memory and that we cannot vectorise sub-register
+  // loads/stores, so we set a minimum element count of 32 (the warp
+  // size and number of shared memory banks) and minimum bitwidth of
+  // 32 (the width per bank of the shared memory load/store unit).
+  int64_t convertLayoutBytes = getByteCount(convertOp.getSrc(), 32, 32);
+
+  // We measure costs in standardised milli-SM-cycles. The smem load
+  // and store each cost 8 * convertLayoutBytes, and then we double
+  // it to account for extra cost due to synchronisation.
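+  // i.e. (8 for the load + 8 for the store) * 2 for synchronisation = 32.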
+  int64_t convertLayoutCost = 32 * convertLayoutBytes;
+  int64_t rematerialisationCost = 0;
+
+  // Evaluate single-use status for every operation in the slice and
+  // accumulate the duplication cost of multi-use operations.
+  for (Operation *op : sliceOps) {
+    auto dialect = op->getDialect();
+    if (isOpSingleUse(op)) {
+      // when we rematerialise, this operation does not get duplicated
+      // so it does not contribute to our cost model:
+      continue;
+    } else if (isa<arith::ConstantOp>(op)) {
+      // special-case: arith.constant has zero cost
+      continue;
+    } else if (isa<LoadOp>(op)) {
+      // optimistically assume L1-cached:
+      for (Value result : op->getResults()) {
+        rematerialisationCost += 8 * getByteCount(result);
+      }
+    } else if (isa<arith::ArithDialect, math::MathDialect>(dialect)) {
+      // this is an arithmetic operation; we distinguish between cheap
+      // operations (such as floating point add/mul which can be fused
+      // as halves of a single-cycle FMA instruction) and expensive
+      // operations which use the special function unit and/or involve
+      // multiple instructions.
+      int64_t multiplier = isExpensiveMathOp(op) ? 8 : 1;
+      for (Value result : op->getResults()) {
+        rematerialisationCost += multiplier * getByteCount(result);
+      }
+    }
+  }
+
+  LLVM_DEBUG({
+    DBGS() << "convert layout cost: " << convertLayoutCost << "\n";
+    DBGS() << "rematerialisation cost: " << rematerialisationCost << "\n";
+  });
+
+  if (rematerialisationCost > convertLayoutCost) {
+    LDBG("skipped rematerialization due to higher cost");
+    return;
+  }
+
   LLVM_DEBUG({
     DBGS() << "remat convert op " << convertOp << '\n';
     for (Value v : slice)
       DBGS() << "  " << v << '\n';
   });
-  // 2. Rewrite the slice.
+
+  // 3. Rewrite the slice.
   rewriteSlice(slice, layout, convertOp);
 }
 
@@ -1179,30 +1313,32 @@ void LayoutRematerialization::hoistConvertDotOperand(
           { DBGS() << "Block arguments not supported. Got " << v << "\n"; });
       return;
     }
-    auto loadOp = dyn_cast<LoadOp>(v.getDefiningOp());
-    // We expect the leaves of the slice to be Load or arith::Constant
-    // This could be generalised if necessary
-    if (!loadOp) {
+
+    // We expect the leaves of the slice to be Load, DescriptorLoad or
+    // arith::Constant. This could be generalised if necessary.
+    if (!isa<LoadOp, DescriptorLoadOp>(v.getDefiningOp())) {
       auto op = v.getDefiningOp();
       if (isa<arith::ConstantOp>(op) || noDataMovement(op)) {
         innerSlice.insert(v);
         continue;
       } else {
         LLVM_DEBUG({
-          DBGS() << "Leaves must be Load or Constant. Got " << v << "\n";
+          DBGS() << "Leaves must be Load, DescriptorLoad or Constant. Got "
+                 << v << "\n";
         });
         return;
       }
     }
+    Operation *loadOp = v.getDefiningOp();
     builder.setInsertionPointAfter(loadOp);
-    auto type = dyn_cast<RankedTensorType>(loadOp.getType());
+    auto type = dyn_cast<RankedTensorType>(loadOp->getResult(0).getType());
     if (!type)
       continue;
     auto newType = RankedTensorType::get(type.getShape(), type.getElementType(),
-                                         layout[loadOp]);
+                                         layout[loadOp->getResult(0)]);
     auto newConvertOp = builder.create<ConvertLayoutOp>(
-        convertOp.getLoc(), newType, loadOp.getResult());
-    mapping.map(loadOp.getResult(), newConvertOp.getResult());
+        convertOp.getLoc(), newType, loadOp->getResult(0));
+    mapping.map(loadOp->getResult(0), newConvertOp.getResult());
   }
 
   if (innerSlice.empty()) {