Skip to content

Commit 8ebe58e

Browse files
authored
Revert "[BACKEND] backwardRematerialization cost model" (#6696)
Reverts triton-lang/triton#6667 This is causing a regression in an internal OAI workload
1 parent 988d388 commit 8ebe58e

File tree

1 file changed

+1
-135
lines changed

1 file changed

+1
-135
lines changed

lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp

Lines changed: 1 addition & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -1057,40 +1057,6 @@ void LayoutRematerialization::hoistConvertIntoConditionals() {
10571057
}
10581058
}
10591059

1060-
static bool isExpensiveMathOp(Operation *op) {
1061-
// These operations are either multiple instructions or have throughput
1062-
// lower than 16 according to the arithmetic instructions table in:
1063-
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#arithmetic-instructions
1064-
return isa<arith::DivFOp, math::ErfcOp, math::SinhOp, math::CoshOp,
1065-
math::TanhOp, math::AsinhOp, math::AcoshOp, math::AtanhOp,
1066-
math::CtPopOp, math::CountLeadingZerosOp,
1067-
math::CountTrailingZerosOp, math::ExpOp, math::Exp2Op,
1068-
math::ExpM1Op, math::LogOp, math::Log2Op, math::Log10Op,
1069-
math::Log1pOp, math::SinOp, math::CosOp, math::TanOp, math::AsinOp,
1070-
math::AcosOp, math::AtanOp, math::Atan2Op, math::PowFOp,
1071-
math::SqrtOp, math::RsqrtOp, math::ErfOp, math::CbrtOp>(op);
1072-
}
1073-
1074-
static int64_t getByteCount(Value result, int64_t minElementCount = 0,
1075-
int64_t minBitWidth = 0) {
1076-
int64_t elementCount = 0;
1077-
int64_t dtypeBitWidth = 0;
1078-
if (auto tensorTy = dyn_cast<RankedTensorType>(result.getType())) {
1079-
elementCount = tensorTy.getNumElements();
1080-
auto elemType = tensorTy.getElementType();
1081-
if (elemType.isIntOrFloat()) {
1082-
dtypeBitWidth = elemType.getIntOrFloatBitWidth();
1083-
}
1084-
}
1085-
if (elementCount < minElementCount) {
1086-
elementCount = minElementCount;
1087-
}
1088-
if (dtypeBitWidth < minBitWidth) {
1089-
dtypeBitWidth = minBitWidth;
1090-
}
1091-
return (elementCount * dtypeBitWidth) >> 3;
1092-
}
1093-
10941060
void LayoutRematerialization::backwardRematerialization(
10951061
ConvertLayoutOp convertOp) {
10961062
// DotOperand is hoisted by hoistDotOperand
@@ -1122,112 +1088,12 @@ void LayoutRematerialization::backwardRematerialization(
11221088
return;
11231089
}
11241090

1125-
// 2. Determine whether rematerialisation is beneficial.
1126-
1127-
// Identify all operations in the slice
1128-
SetVector<Operation *> sliceOps;
1129-
for (Value v : slice) {
1130-
if (Operation *op = v.getDefiningOp()) {
1131-
sliceOps.insert(op);
1132-
}
1133-
}
1134-
1135-
// Compute single-use operations
1136-
DenseMap<Operation *, bool> isSingleUse;
1137-
std::function<bool(Operation *)> isOpSingleUse;
1138-
isOpSingleUse = [&](Operation *op) -> bool {
1139-
// lookup in memoization array:
1140-
auto it = isSingleUse.find(op);
1141-
if (it != isSingleUse.end()) {
1142-
return it->second;
1143-
}
1144-
1145-
bool singleUse = true;
1146-
1147-
for (Value result : op->getResults()) {
1148-
for (Operation *user : result.getUsers()) {
1149-
if (user == convertOp) {
1150-
continue;
1151-
}
1152-
if (sliceOps.contains(user)) {
1153-
if (!isOpSingleUse(user)) {
1154-
singleUse = false;
1155-
break;
1156-
}
1157-
} else {
1158-
singleUse = false;
1159-
break;
1160-
}
1161-
}
1162-
if (!singleUse) {
1163-
break;
1164-
}
1165-
}
1166-
1167-
// insert into memoization array:
1168-
isSingleUse[op] = singleUse;
1169-
return singleUse;
1170-
};
1171-
1172-
// Measure the number of bytes that we're manipulating with the
1173-
// ConvertLayoutOp. We pessimistically assume that we round-trip
1174-
// through shared memory and that we cannot vectorise sub-register
1175-
// loads/stores, so we set a minimum element count of 32 (the warp
1176-
// size and number of shared memory banks) and minimum bitwidth of
1177-
// 32 (the width per bank of the shared memory load/store unit).
1178-
int64_t convertLayoutBytes = getByteCount(convertOp.getSrc(), 32, 32);
1179-
1180-
// We measure costs in standardised milli-SM-cycles. This gives:
1181-
// smem load/store: 8 * byte count
1182-
// synchronisation: 1024 (assuming 4 warps per block)
1183-
int64_t convertLayoutCost = 16 * convertLayoutBytes + 1024;
1184-
int64_t rematerialisationCost = 0;
1185-
1186-
// Evaluate single-use status for every operation in slice
1187-
for (Operation *op : sliceOps) {
1188-
auto dialect = op->getDialect();
1189-
if (isOpSingleUse(op)) {
1190-
// when we rematerialise, this operation does not get duplicated
1191-
// so it does not contribute to our cost model:
1192-
continue;
1193-
} else if (isa<arith::ConstantOp>(op)) {
1194-
// special-case: arith.constant has zero cost
1195-
continue;
1196-
} else if (isa<LoadOp>(op)) {
1197-
// optimistically assume L1-cached:
1198-
for (Value result : op->getResults()) {
1199-
rematerialisationCost += 8 * getByteCount(result);
1200-
}
1201-
} else if (isa<arith::ArithDialect, math::MathDialect>(dialect)) {
1202-
// this is an arithmetic operation; we distinguish between cheap
1203-
// operations (such as floating point add/mul which can be fused
1204-
// as halves of a single-cycle FMA instruction) and expensive
1205-
// operations which use the special function unit and/or involve
1206-
// multiple instructions.
1207-
int64_t multiplier = isExpensiveMathOp(op) ? 8 : 1;
1208-
for (Value result : op->getResults()) {
1209-
rematerialisationCost += multiplier * getByteCount(result);
1210-
}
1211-
}
1212-
}
1213-
1214-
LLVM_DEBUG({
1215-
DBGS() << " convert layout cost: " << convertLayoutCost << "\n";
1216-
DBGS() << " rematerialisation cost: " << rematerialisationCost << "\n";
1217-
});
1218-
1219-
if (rematerialisationCost > convertLayoutCost) {
1220-
LDBG(" skipped rematerialization due to higher cost");
1221-
return;
1222-
}
1223-
12241091
LLVM_DEBUG({
12251092
DBGS() << " remat convert op " << convertOp << '\n';
12261093
for (Value v : slice)
12271094
DBGS() << " " << v << '\n';
12281095
});
1229-
1230-
// 3. Rewrite the slice.
1096+
// 2. Rewrite the slice.
12311097
rewriteSlice(slice, layout, convertOp);
12321098
}
12331099

0 commit comments

Comments
 (0)