Commit 3ddfd2a

Sync rematerialization heuristic with upstream. (#4700)

Signed-off-by: Tiotto, Ettore <[email protected]>

1 parent 4e7e5ef

2 files changed: +114 -4 lines changed

python/test/unit/language/test_core.py

Lines changed: 0 additions & 3 deletions

@@ -3092,9 +3092,6 @@ def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
 
 
 def test_no_rematerialization_op(device):
-    if is_xpu():
-        pytest.skip("handle on XPU")
-
     if torch.version.hip:
         pytest.skip("test not supported on AMD")

third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp

Lines changed: 114 additions & 1 deletion

@@ -1382,12 +1382,125 @@ void LayoutRematerialization::backwardRematerialization(
     return;
   }
 
+  // 2. Determine whether rematerialisation is beneficial.
+
+  // Identify all operations in the slice
+  SetVector<Operation *> sliceOps;
+  for (Value v : slice) {
+    if (Operation *op = v.getDefiningOp()) {
+      sliceOps.insert(op);
+    }
+  }
+
+  // Compute single-use operations
+  DenseMap<Operation *, bool> isSingleUse;
+  std::function<bool(Operation *)> isOpSingleUse;
+  isOpSingleUse = [&](Operation *op) -> bool {
+    // lookup in memoization array:
+    auto it = isSingleUse.find(op);
+    if (it != isSingleUse.end()) {
+      return it->second;
+    }
+
+    bool singleUse = true;
+
+    for (Value result : op->getResults()) {
+      for (Operation *user : result.getUsers()) {
+        if (user == convertOp) {
+          continue;
+        }
+        if (sliceOps.contains(user)) {
+          if (!isOpSingleUse(user)) {
+            singleUse = false;
+            break;
+          }
+        } else {
+          singleUse = false;
+          break;
+        }
+      }
+      if (!singleUse) {
+        break;
+      }
+    }
+
+    // insert into memoization array:
+    isSingleUse[op] = singleUse;
+    return singleUse;
+  };
+
+  // Measure the number of bytes that we're manipulating with the
+  // ConvertLayoutOp. We pessimistically assume that we round-trip
+  // through shared memory and that we cannot vectorise sub-register
+  // loads/stores, so we set a minimum element count of 32 (the warp
+  // size and number of shared memory banks) and minimum bitwidth of
+  // 32 (the width per bank of the shared memory load/store unit).
+  int64_t convertLayoutBytes = getByteCount(convertOp.getSrc(), 32, 32);
+
+  // We measure costs in standardised milli-SM-cycles. The smem load
+  // and store each cost 8 * convertLayoutBytes, and then we double
+  // it to account for extra cost due to synchronisation.
+  int64_t convertLayoutCost = 32 * convertLayoutBytes;
+  int64_t rematerialisationCost = 0;
+
+  // Evaluate single-use status for every operation in slice
+  for (Operation *op : sliceOps) {
+    auto dialect = op->getDialect();
+    if (isOpSingleUse(op)) {
+      // when we rematerialise, this operation does not get duplicated
+      // so it does not contribute to our cost model:
+      continue;
+    } else if (isa<arith::ConstantOp>(op)) {
+      // special-case: arith.constant has zero cost
+      continue;
+    } else if (isa<LoadOp>(op) || isa<LocalLoadOp>(op)) {
+      // optimistically assume L1-cached:
+      for (Value result : op->getResults()) {
+        rematerialisationCost += 8 * getByteCount(result);
+      }
+    } else if (isa<arith::ArithDialect, math::MathDialect>(dialect)) {
+      // this is an arithmetic operation; we distinguish between cheap
+      // operations (such as floating point add/mul which can be fused
+      // as halves of a single-cycle FMA instruction) and expensive
+      // operations which use the special function unit and/or involve
+      // multiple instructions.
+      int64_t multiplier = isExpensiveMathOp(op) ? 8 : 1;
+      for (Value result : op->getResults()) {
+        rematerialisationCost += multiplier * getByteCount(result);
+      }
+    } else if (isa<ReduceOp>(op)) {
+      // Reduce ops introduce significant cost.
+      auto reduceOp = dyn_cast<ReduceOp>(op);
+      ReduceOpHelper helper(reduceOp);
+      if (!helper.isAssociative()) {
+        // We shouldn't rematerialize a non-associative reduce op if it
+        // has multiple use chains.
+        LDBG("  skipped rematerialization due to non-associative reduce in the "
+             "slice");
+        return;
+      }
+      rematerialisationCost += helper.getIntraWarpSizeWithUniqueData();
+      rematerialisationCost += 8 * helper.getInterWarpSizeWithUniqueData();
+    }
+  }
+
+  LLVM_DEBUG({
+    DBGS() << "  convert layout cost: " << convertLayoutCost << "\n";
+    DBGS() << "  rematerialisation cost: " << rematerialisationCost << "\n";
+  });
+
+  if (rematerialisationCost > convertLayoutCost) {
+    LDBG("  skipped rematerialization due to higher cost");
+    return;
+  }
+
   LLVM_DEBUG({
     DBGS() << "  remat convert op " << convertOp << '\n';
     for (Value v : slice)
       DBGS() << "    " << v << '\n';
   });
-  // 2. Rewrite the slice.
+
+  // 3. Rewrite the slice.
   rewriteSlice(slice, layout, convertOp);
 }
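Note: the core of the new heuristic is the memoized isOpSingleUse walk above; an operation only adds to the rematerialisation cost if duplicating the slice would actually duplicate it. The following self-contained sketch mirrors that memoized check on a toy use-graph. The Node type and the example graph are invented for illustration and are not the pass's actual data structures.

#include <cstdio>
#include <functional>
#include <map>
#include <set>
#include <vector>

// Toy stand-in for an IR op: each node lists the nodes using its result.
struct Node {
  const char *name;
  std::vector<Node *> users;
};

int main() {
  // Use-graph: a feeds b and c; b feeds the convert op; c lies outside
  // the slice being considered for rematerialization.
  Node convert{"convert", {}};
  Node b{"b", {&convert}};
  Node c{"c", {}};
  Node a{"a", {&b, &c}};

  std::set<Node *> sliceOps = {&a, &b};

  // Memoized recursive check mirroring isOpSingleUse: a node is
  // "single use" if every user is the convert op itself or a
  // single-use node inside the slice.
  std::map<Node *, bool> memo;
  std::function<bool(Node *)> isSingleUse = [&](Node *n) -> bool {
    auto it = memo.find(n);
    if (it != memo.end())
      return it->second;
    bool single = true;
    for (Node *user : n->users) {
      if (user == &convert)
        continue;
      if (!sliceOps.count(user) || !isSingleUse(user)) {
        single = false;
        break;
      }
    }
    return memo[n] = single;
  };

  for (Node *n : {&a, &b})
    std::printf("%s: %s\n", n->name, isSingleUse(n) ? "single use" : "shared");
  // a is shared (it also feeds c outside the slice), so rematerializing
  // would duplicate it and it contributes cost; b is single use and
  // duplicates for free.
}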
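For intuition about the magnitudes being compared, here is a minimal standalone sketch of the cost arithmetic under stated assumptions: the tensor shape, element width, and the mix of operations in the slice are hypothetical, and byteCount only imitates what getByteCount would report on real IR, including the pessimistic 32-element / 32-bit minimums applied to the layout conversion.

#include <algorithm>
#include <cstdint>
#include <iostream>

// Hypothetical stand-in for getByteCount(v, minElems, minBitWidth):
// bytes moved for a tensor value, clamped to pessimistic minimums.
int64_t byteCount(int64_t numElems, int64_t bitWidth, int64_t minElems = 1,
                  int64_t minBitWidth = 1) {
  return std::max(numElems, minElems) * std::max(bitWidth, minBitWidth) / 8;
}

int main() {
  // Assume a 128x64 fp16 tensor feeding the ConvertLayoutOp.
  int64_t numElems = 128 * 64;
  int64_t bitWidth = 16;

  // Convert cost in milli-SM-cycles: smem store + load at 8 per byte
  // each, doubled for synchronisation => 32 * bytes.
  int64_t convertLayoutCost = 32 * byteCount(numElems, bitWidth, 32, 32);

  // Rematerialisation cost: say the non-single-use part of the slice
  // re-executes one L1-cached load (8x per byte) and two cheap
  // elementwise arith ops (1x per byte each).
  int64_t rematerialisationCost =
      8 * byteCount(numElems, bitWidth) + 2 * byteCount(numElems, bitWidth);

  std::cout << "convert layout cost:    " << convertLayoutCost << "\n"
            << "rematerialisation cost: " << rematerialisationCost << "\n"
            << (rematerialisationCost > convertLayoutCost
                    ? "skip rematerialization\n"
                    : "rematerialize\n");
}

With these numbers the duplicated work is far cheaper than a shared-memory round trip, so the pass would go on to rewrite the slice; a slice dominated by expensive math or reductions tips the comparison the other way.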
