Skip to content

Commit 2c59df5

Browse files
wenqiny and apgoucher
authored
[IR] tune the rematerialization heuristic to avoid harmful rematerialization (#7240)
Fix #6647 This PR supports more operators when calculating `rematerialisationCost`, which controls whether or not we should rematerialize a slice that contains a `ttg.convert_layout`. In the above issue, there are some `ttg.local_load` and `tt.reduce` operators in the slice; they do a lot of shared-memory loading and reduction computation, but they are not accounted for in the current `rematerialisationCost`, so it may cause some harmful rematerialization. I added some heuristics for these operators, PTAL, thanks! <!--- The core Triton is a small number of people, and we receive many PRs (thank you!). To help us review your code more quickly, **if you are a new contributor (less than 3 PRs merged) we ask that you complete the following tasks and include the filled-out checklist in your PR description.** Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> # New contributor declaration - [x] I am not making a trivial change, such as fixing a typo in a comment. - [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [x] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [ ] This PR does not need a test because `FILL THIS IN`. - Select one of the following. - [x] I have not added any `lit` tests. - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.) --------- Co-authored-by: apgoucher <[email protected]>
1 parent 6c3d943 commit 2c59df5

File tree

2 files changed

+51
-1
lines changed

2 files changed

+51
-1
lines changed

lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1193,7 +1193,7 @@ void LayoutRematerialization::backwardRematerialization(
11931193
} else if (isa<arith::ConstantOp>(op)) {
11941194
// special-case: arith.constant has zero cost
11951195
continue;
1196-
} else if (isa<LoadOp>(op)) {
1196+
} else if (isa<LoadOp>(op) || isa<LocalLoadOp>(op)) {
11971197
// optimistically assume L1-cached:
11981198
for (Value result : op->getResults()) {
11991199
rematerialisationCost += 8 * getByteCount(result);
@@ -1208,6 +1208,12 @@ void LayoutRematerialization::backwardRematerialization(
12081208
for (Value result : op->getResults()) {
12091209
rematerialisationCost += multiplier * getByteCount(result);
12101210
}
1211+
} else if (isa<ReduceOp>(op)) {
1212+
// Reduce op introduce much cost.
1213+
auto reduceOp = dyn_cast<ReduceOp>(op);
1214+
ReduceOpHelper helper(reduceOp);
1215+
rematerialisationCost += helper.getIntraWarpSizeWithUniqueData();
1216+
rematerialisationCost += 8 * helper.getInterWarpSizeWithUniqueData();
12111217
}
12121218
}
12131219

python/test/unit/language/test_core.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2972,6 +2972,50 @@ def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
29722972
]
29732973

29742974

2975+
def test_no_rematerialization_op():
2976+
2977+
if torch.version.hip:
2978+
pytest.skip("test not supported on AMD")
2979+
2980+
@triton.jit
2981+
def kernel(
2982+
input_data,
2983+
sum_output,
2984+
out_1,
2985+
BLOCK_SIZE: tl.constexpr,
2986+
DATA_DIM: tl.constexpr,
2987+
DATA_LEN: tl.constexpr,
2988+
loop_stages: tl.constexpr,
2989+
):
2990+
tl.static_assert(DATA_LEN % BLOCK_SIZE == 0)
2991+
for curr_block_idx in tl.range(0, DATA_LEN // BLOCK_SIZE, num_stages=loop_stages):
2992+
my_idxs = BLOCK_SIZE * curr_block_idx + tl.arange(0, BLOCK_SIZE)
2993+
values = tl.load(input_data + DATA_DIM * my_idxs[:, None] + tl.arange(0, DATA_DIM)[None, :])
2994+
accum = tl.sum(values, axis=-1).to(tl.float32)
2995+
tl.store(sum_output + my_idxs, accum)
2996+
sum_plus_0 = tl.full((1, 2), 0, tl.float32) + accum[:, None]
2997+
tl.store(out_1 + my_idxs[:, None] * 2 + tl.arange(0, 2)[None, :], sum_plus_0)
2998+
2999+
device = "cuda"
3000+
data_len = 32
3001+
data_dim = 64
3002+
torch.manual_seed(0)
3003+
input_data = torch.randn((data_len, data_dim), dtype=torch.float32, device=device)
3004+
sum_output = torch.full((data_len, ), -1, dtype=torch.float32, device=device)
3005+
out_1 = torch.full((data_len, 2), -1, dtype=torch.float32, device=device)
3006+
compiled_kernel = kernel[(1, )](
3007+
input_data=input_data,
3008+
sum_output=sum_output,
3009+
out_1=out_1,
3010+
DATA_DIM=data_dim,
3011+
DATA_LEN=data_len,
3012+
BLOCK_SIZE=16,
3013+
num_warps=1,
3014+
loop_stages=2,
3015+
)
3016+
assert compiled_kernel.asm["ttgir"].count('"tt.reduce"') == 1, "we shouldn't rematerialize tt.reduce"
3017+
3018+
29753019
@pytest.mark.parametrize("M, N", [[32, 16], [32, 32], [32, 64], [64, 32]])
29763020
@pytest.mark.parametrize("src_layout", scan_layouts)
29773021
@pytest.mark.parametrize("axis", [0, 1])

0 commit comments

Comments
 (0)