@@ -1382,12 +1382,125 @@ void LayoutRematerialization::backwardRematerialization(
13821382 return ;
13831383 }
13841384
1385+ // 2. Determine whether rematerialisation is beneficial.
1386+
1387+ // Identify all operations in the slice
1388+ SetVector<Operation *> sliceOps;
1389+ for (Value v : slice) {
1390+ if (Operation *op = v.getDefiningOp ()) {
1391+ sliceOps.insert (op);
1392+ }
1393+ }
1394+
1395+ // Compute single-use operations
1396+ DenseMap<Operation *, bool > isSingleUse;
1397+ std::function<bool (Operation *)> isOpSingleUse;
1398+ isOpSingleUse = [&](Operation *op) -> bool {
1399+ // lookup in memoization array:
1400+ auto it = isSingleUse.find (op);
1401+ if (it != isSingleUse.end ()) {
1402+ return it->second ;
1403+ }
1404+
1405+ bool singleUse = true ;
1406+
1407+ for (Value result : op->getResults ()) {
1408+ for (Operation *user : result.getUsers ()) {
1409+ if (user == convertOp) {
1410+ continue ;
1411+ }
1412+ if (sliceOps.contains (user)) {
1413+ if (!isOpSingleUse (user)) {
1414+ singleUse = false ;
1415+ break ;
1416+ }
1417+ } else {
1418+ singleUse = false ;
1419+ break ;
1420+ }
1421+ }
1422+ if (!singleUse) {
1423+ break ;
1424+ }
1425+ }
1426+
1427+ // insert into memoization array:
1428+ isSingleUse[op] = singleUse;
1429+ return singleUse;
1430+ };
1431+
1432+ // Measure the number of bytes that we're manipulating with the
1433+ // ConvertLayoutOp. We pessimistically assume that we round-trip
1434+ // through shared memory and that we cannot vectorise sub-register
1435+ // loads/stores, so we set a minimum element count of 32 (the warp
1436+ // size and number of shared memory banks) and minimum bitwidth of
1437+ // 32 (the width per bank of the shared memory load/store unit).
1438+ int64_t convertLayoutBytes = getByteCount (convertOp.getSrc (), 32 , 32 );
1439+
1440+ // We measure costs in standardised milli-SM-cycles. The smem load
1441+ // and store each cost 8 * convertLayoutBytes, and then we double
1442+ // it to account for extra cost due to synchronisation.
1443+ int64_t convertLayoutCost = 32 * convertLayoutBytes;
1444+ int64_t rematerialisationCost = 0 ;
1445+
1446+ // Evaluate single-use status for every operation in slice
1447+ for (Operation *op : sliceOps) {
1448+ auto dialect = op->getDialect ();
1449+ if (isOpSingleUse (op)) {
1450+ // when we rematerialise, this operation does not get duplicated
1451+ // so it does not contribute to our cost model:
1452+ continue ;
1453+ } else if (isa<arith::ConstantOp>(op)) {
1454+ // special-case: arith.constant has zero cost
1455+ continue ;
1456+ } else if (isa<LoadOp>(op) || isa<LocalLoadOp>(op)) {
1457+ // optimistically assume L1-cached:
1458+ for (Value result : op->getResults ()) {
1459+ rematerialisationCost += 8 * getByteCount (result);
1460+ }
1461+ } else if (isa<arith::ArithDialect, math::MathDialect>(dialect)) {
1462+ // this is an arithmetic operation; we distinguish between cheap
1463+ // operations (such as floating point add/mul which can be fused
1464+ // as halves of a single-cycle FMA instruction) and expensive
1465+ // operations which use the special function unit and/or involve
1466+ // multiple instructions.
1467+ int64_t multiplier = isExpensiveMathOp (op) ? 8 : 1 ;
1468+ for (Value result : op->getResults ()) {
1469+ rematerialisationCost += multiplier * getByteCount (result);
1470+ }
1471+ } else if (isa<ReduceOp>(op)) {
1472+ // Reduce op introduce much cost.
1473+ auto reduceOp = dyn_cast<ReduceOp>(op);
1474+ ReduceOpHelper helper (reduceOp);
1475+ if (!helper.isAssociative ()) {
1476+ // We shouldn't rematerize a no associative reduce op if it has multiple
1477+ // use chain.
1478+ LDBG (" skipped rematerialization due to non-associative reduce in the "
1479+ " slice" );
1480+ return ;
1481+ }
1482+ rematerialisationCost += helper.getIntraWarpSizeWithUniqueData ();
1483+ rematerialisationCost += 8 * helper.getInterWarpSizeWithUniqueData ();
1484+ }
1485+ }
1486+
1487+ LLVM_DEBUG ({
1488+ DBGS () << " convert layout cost: " << convertLayoutCost << " \n " ;
1489+ DBGS () << " rematerialisation cost: " << rematerialisationCost << " \n " ;
1490+ });
1491+
1492+ if (rematerialisationCost > convertLayoutCost) {
1493+ LDBG (" skipped rematerialization due to higher cost" );
1494+ return ;
1495+ }
1496+
13851497 LLVM_DEBUG ({
13861498 DBGS () << " remat convert op " << convertOp << ' \n ' ;
13871499 for (Value v : slice)
13881500 DBGS () << " " << v << ' \n ' ;
13891501 });
1390- // 2. Rewrite the slice.
1502+
1503+ // 3. Rewrite the slice.
13911504 rewriteSlice (slice, layout, convertOp);
13921505}
13931506
0 commit comments