Skip to content

Commit 553d01d

Browse files
authored
Re-land cost model with tweaks to avoid perf regression on GB200 (#6699)
Tested on both H100 and GB200 on our internal benchmarks
1 parent 8ebe58e commit 553d01d

File tree

1 file changed

+135
-1
lines changed

1 file changed

+135
-1
lines changed

lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp

Lines changed: 135 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1057,6 +1057,40 @@ void LayoutRematerialization::hoistConvertIntoConditionals() {
10571057
}
10581058
}
10591059

1060+
static bool isExpensiveMathOp(Operation *op) {
1061+
// These operations are either multiple instructions or have throughput
1062+
// lower than 16 according to the arithmetic instructions table in:
1063+
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#arithmetic-instructions
1064+
return isa<arith::DivFOp, math::ErfcOp, math::SinhOp, math::CoshOp,
1065+
math::TanhOp, math::AsinhOp, math::AcoshOp, math::AtanhOp,
1066+
math::CtPopOp, math::CountLeadingZerosOp,
1067+
math::CountTrailingZerosOp, math::ExpOp, math::Exp2Op,
1068+
math::ExpM1Op, math::LogOp, math::Log2Op, math::Log10Op,
1069+
math::Log1pOp, math::SinOp, math::CosOp, math::TanOp, math::AsinOp,
1070+
math::AcosOp, math::AtanOp, math::Atan2Op, math::PowFOp,
1071+
math::SqrtOp, math::RsqrtOp, math::ErfOp, math::CbrtOp>(op);
1072+
}
1073+
1074+
static int64_t getByteCount(Value result, int64_t minElementCount = 0,
1075+
int64_t minBitWidth = 0) {
1076+
int64_t elementCount = 0;
1077+
int64_t dtypeBitWidth = 0;
1078+
if (auto tensorTy = dyn_cast<RankedTensorType>(result.getType())) {
1079+
elementCount = tensorTy.getNumElements();
1080+
auto elemType = tensorTy.getElementType();
1081+
if (elemType.isIntOrFloat()) {
1082+
dtypeBitWidth = elemType.getIntOrFloatBitWidth();
1083+
}
1084+
}
1085+
if (elementCount < minElementCount) {
1086+
elementCount = minElementCount;
1087+
}
1088+
if (dtypeBitWidth < minBitWidth) {
1089+
dtypeBitWidth = minBitWidth;
1090+
}
1091+
return (elementCount * dtypeBitWidth) >> 3;
1092+
}
1093+
10601094
void LayoutRematerialization::backwardRematerialization(
10611095
ConvertLayoutOp convertOp) {
10621096
// DotOperand is hoisted by hoistDotOperand
@@ -1088,12 +1122,112 @@ void LayoutRematerialization::backwardRematerialization(
10881122
return;
10891123
}
10901124

1125+
// 2. Determine whether rematerialisation is beneficial.
1126+
1127+
// Identify all operations in the slice
1128+
SetVector<Operation *> sliceOps;
1129+
for (Value v : slice) {
1130+
if (Operation *op = v.getDefiningOp()) {
1131+
sliceOps.insert(op);
1132+
}
1133+
}
1134+
1135+
// Compute single-use operations
1136+
DenseMap<Operation *, bool> isSingleUse;
1137+
std::function<bool(Operation *)> isOpSingleUse;
1138+
isOpSingleUse = [&](Operation *op) -> bool {
1139+
// lookup in memoization array:
1140+
auto it = isSingleUse.find(op);
1141+
if (it != isSingleUse.end()) {
1142+
return it->second;
1143+
}
1144+
1145+
bool singleUse = true;
1146+
1147+
for (Value result : op->getResults()) {
1148+
for (Operation *user : result.getUsers()) {
1149+
if (user == convertOp) {
1150+
continue;
1151+
}
1152+
if (sliceOps.contains(user)) {
1153+
if (!isOpSingleUse(user)) {
1154+
singleUse = false;
1155+
break;
1156+
}
1157+
} else {
1158+
singleUse = false;
1159+
break;
1160+
}
1161+
}
1162+
if (!singleUse) {
1163+
break;
1164+
}
1165+
}
1166+
1167+
// insert into memoization array:
1168+
isSingleUse[op] = singleUse;
1169+
return singleUse;
1170+
};
1171+
1172+
// Measure the number of bytes that we're manipulating with the
1173+
// ConvertLayoutOp. We pessimistically assume that we round-trip
1174+
// through shared memory and that we cannot vectorise sub-register
1175+
// loads/stores, so we set a minimum element count of 32 (the warp
1176+
// size and number of shared memory banks) and minimum bitwidth of
1177+
// 32 (the width per bank of the shared memory load/store unit).
1178+
int64_t convertLayoutBytes = getByteCount(convertOp.getSrc(), 32, 32);
1179+
1180+
// We measure costs in standardised milli-SM-cycles. The smem load
1181+
// and store each cost 8 * convertLayoutBytes, and then we double
1182+
// it to account for extra cost due to synchronisation.
1183+
int64_t convertLayoutCost = 32 * convertLayoutBytes;
1184+
int64_t rematerialisationCost = 0;
1185+
1186+
// Evaluate single-use status for every operation in slice
1187+
for (Operation *op : sliceOps) {
1188+
auto dialect = op->getDialect();
1189+
if (isOpSingleUse(op)) {
1190+
// when we rematerialise, this operation does not get duplicated
1191+
// so it does not contribute to our cost model:
1192+
continue;
1193+
} else if (isa<arith::ConstantOp>(op)) {
1194+
// special-case: arith.constant has zero cost
1195+
continue;
1196+
} else if (isa<LoadOp>(op)) {
1197+
// optimistically assume L1-cached:
1198+
for (Value result : op->getResults()) {
1199+
rematerialisationCost += 8 * getByteCount(result);
1200+
}
1201+
} else if (isa<arith::ArithDialect, math::MathDialect>(dialect)) {
1202+
// this is an arithmetic operation; we distinguish between cheap
1203+
// operations (such as floating point add/mul which can be fused
1204+
// as halves of a single-cycle FMA instruction) and expensive
1205+
// operations which use the special function unit and/or involve
1206+
// multiple instructions.
1207+
int64_t multiplier = isExpensiveMathOp(op) ? 8 : 1;
1208+
for (Value result : op->getResults()) {
1209+
rematerialisationCost += multiplier * getByteCount(result);
1210+
}
1211+
}
1212+
}
1213+
1214+
LLVM_DEBUG({
1215+
DBGS() << " convert layout cost: " << convertLayoutCost << "\n";
1216+
DBGS() << " rematerialisation cost: " << rematerialisationCost << "\n";
1217+
});
1218+
1219+
if (rematerialisationCost > convertLayoutCost) {
1220+
LDBG(" skipped rematerialization due to higher cost");
1221+
return;
1222+
}
1223+
10911224
LLVM_DEBUG({
10921225
DBGS() << " remat convert op " << convertOp << '\n';
10931226
for (Value v : slice)
10941227
DBGS() << " " << v << '\n';
10951228
});
1096-
// 2. Rewrite the slice.
1229+
1230+
// 3. Rewrite the slice.
10971231
rewriteSlice(slice, layout, convertOp);
10981232
}
10991233

0 commit comments

Comments
 (0)