@@ -1057,40 +1057,6 @@ void LayoutRematerialization::hoistConvertIntoConditionals() {
10571057 }
10581058}
10591059
1060- static bool isExpensiveMathOp (Operation *op) {
1061- // These operations are either multiple instructions or have throughput
1062- // lower than 16 according to the arithmetic instructions table in:
1063- // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#arithmetic-instructions
1064- return isa<arith::DivFOp, math::ErfcOp, math::SinhOp, math::CoshOp,
1065- math::TanhOp, math::AsinhOp, math::AcoshOp, math::AtanhOp,
1066- math::CtPopOp, math::CountLeadingZerosOp,
1067- math::CountTrailingZerosOp, math::ExpOp, math::Exp2Op,
1068- math::ExpM1Op, math::LogOp, math::Log2Op, math::Log10Op,
1069- math::Log1pOp, math::SinOp, math::CosOp, math::TanOp, math::AsinOp,
1070- math::AcosOp, math::AtanOp, math::Atan2Op, math::PowFOp,
1071- math::SqrtOp, math::RsqrtOp, math::ErfOp, math::CbrtOp>(op);
1072- }
1073-
1074- static int64_t getByteCount (Value result, int64_t minElementCount = 0 ,
1075- int64_t minBitWidth = 0 ) {
1076- int64_t elementCount = 0 ;
1077- int64_t dtypeBitWidth = 0 ;
1078- if (auto tensorTy = dyn_cast<RankedTensorType>(result.getType ())) {
1079- elementCount = tensorTy.getNumElements ();
1080- auto elemType = tensorTy.getElementType ();
1081- if (elemType.isIntOrFloat ()) {
1082- dtypeBitWidth = elemType.getIntOrFloatBitWidth ();
1083- }
1084- }
1085- if (elementCount < minElementCount) {
1086- elementCount = minElementCount;
1087- }
1088- if (dtypeBitWidth < minBitWidth) {
1089- dtypeBitWidth = minBitWidth;
1090- }
1091- return (elementCount * dtypeBitWidth) >> 3 ;
1092- }
1093-
10941060void LayoutRematerialization::backwardRematerialization (
10951061 ConvertLayoutOp convertOp) {
10961062 // DotOperand is hoisted by hoistDotOperand
@@ -1122,112 +1088,12 @@ void LayoutRematerialization::backwardRematerialization(
11221088 return ;
11231089 }
11241090
1125- // 2. Determine whether rematerialisation is beneficial.
1126-
1127- // Identify all operations in the slice
1128- SetVector<Operation *> sliceOps;
1129- for (Value v : slice) {
1130- if (Operation *op = v.getDefiningOp ()) {
1131- sliceOps.insert (op);
1132- }
1133- }
1134-
1135- // Compute single-use operations
1136- DenseMap<Operation *, bool > isSingleUse;
1137- std::function<bool (Operation *)> isOpSingleUse;
1138- isOpSingleUse = [&](Operation *op) -> bool {
1139- // lookup in memoization array:
1140- auto it = isSingleUse.find (op);
1141- if (it != isSingleUse.end ()) {
1142- return it->second ;
1143- }
1144-
1145- bool singleUse = true ;
1146-
1147- for (Value result : op->getResults ()) {
1148- for (Operation *user : result.getUsers ()) {
1149- if (user == convertOp) {
1150- continue ;
1151- }
1152- if (sliceOps.contains (user)) {
1153- if (!isOpSingleUse (user)) {
1154- singleUse = false ;
1155- break ;
1156- }
1157- } else {
1158- singleUse = false ;
1159- break ;
1160- }
1161- }
1162- if (!singleUse) {
1163- break ;
1164- }
1165- }
1166-
1167- // insert into memoization array:
1168- isSingleUse[op] = singleUse;
1169- return singleUse;
1170- };
1171-
1172- // Measure the number of bytes that we're manipulating with the
1173- // ConvertLayoutOp. We pessimistically assume that we round-trip
1174- // through shared memory and that we cannot vectorise sub-register
1175- // loads/stores, so we set a minimum element count of 32 (the warp
1176- // size and number of shared memory banks) and minimum bitwidth of
1177- // 32 (the width per bank of the shared memory load/store unit).
1178- int64_t convertLayoutBytes = getByteCount (convertOp.getSrc (), 32 , 32 );
1179-
1180- // We measure costs in standardised milli-SM-cycles. This gives:
1181- // smem load/store: 8 * byte count
1182- // synchronisation: 1024 (assuming 4 warps per block)
1183- int64_t convertLayoutCost = 16 * convertLayoutBytes + 1024 ;
1184- int64_t rematerialisationCost = 0 ;
1185-
1186- // Evaluate single-use status for every operation in slice
1187- for (Operation *op : sliceOps) {
1188- auto dialect = op->getDialect ();
1189- if (isOpSingleUse (op)) {
1190- // when we rematerialise, this operation does not get duplicated
1191- // so it does not contribute to our cost model:
1192- continue ;
1193- } else if (isa<arith::ConstantOp>(op)) {
1194- // special-case: arith.constant has zero cost
1195- continue ;
1196- } else if (isa<LoadOp>(op)) {
1197- // optimistically assume L1-cached:
1198- for (Value result : op->getResults ()) {
1199- rematerialisationCost += 8 * getByteCount (result);
1200- }
1201- } else if (isa<arith::ArithDialect, math::MathDialect>(dialect)) {
1202- // this is an arithmetic operation; we distinguish between cheap
1203- // operations (such as floating point add/mul which can be fused
1204- // as halves of a single-cycle FMA instruction) and expensive
1205- // operations which use the special function unit and/or involve
1206- // multiple instructions.
1207- int64_t multiplier = isExpensiveMathOp (op) ? 8 : 1 ;
1208- for (Value result : op->getResults ()) {
1209- rematerialisationCost += multiplier * getByteCount (result);
1210- }
1211- }
1212- }
1213-
1214- LLVM_DEBUG ({
1215- DBGS () << " convert layout cost: " << convertLayoutCost << " \n " ;
1216- DBGS () << " rematerialisation cost: " << rematerialisationCost << " \n " ;
1217- });
1218-
1219- if (rematerialisationCost > convertLayoutCost) {
1220- LDBG (" skipped rematerialization due to higher cost" );
1221- return ;
1222- }
1223-
12241091 LLVM_DEBUG ({
12251092 DBGS () << " remat convert op " << convertOp << ' \n ' ;
12261093 for (Value v : slice)
12271094 DBGS () << " " << v << ' \n ' ;
12281095 });
1229-
1230- // 3. Rewrite the slice.
1096+ // 2. Rewrite the slice.
12311097 rewriteSlice (slice, layout, convertOp);
12321098}
12331099
0 commit comments