
Commit 35f1827

[BACKEND] Fix reduce with slice layout inputs (#5080)
A couple of places were not handling slice layout inputs for reductions. Add support for recursive slice layouts in those cases.
1 parent 781774c commit 35f1827
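
For context, a reduction in Triton produces a result whose encoding is a slice of the operand layout, so chaining reductions yields the nested ("recursive") slice layouts this commit handles. Below is a minimal Python sketch of that nesting using stand-in dataclasses; the class and field names are illustrative only, not Triton's actual SliceEncodingAttr/Blocked attribute classes.

# Minimal model of nested slice layouts (illustrative only, not Triton's classes).
from dataclasses import dataclass
from typing import List, Union

@dataclass
class Blocked:                 # stand-in for a blocked (parent) encoding
    threads_per_warp: List[int]
    order: List[int]           # fastest-varying dimension first

@dataclass
class Slice:                   # stand-in for a slice encoding
    dim: int                   # dimension removed relative to the parent
    parent: Union["Slice", Blocked]

# Reducing a 3D tensor along dim 2 and then the result along dim 1 leaves a
# 1D value carrying a slice-of-slice layout, which is the recursive case the
# reduction lowering now walks through.
parent = Blocked(threads_per_warp=[1, 4, 8], order=[2, 1, 0])
after_first_reduce = Slice(dim=2, parent=parent)
after_second_reduce = Slice(dim=1, parent=after_first_reduce)
print(after_second_reduce)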

3 files changed: +100 -17 lines


lib/Analysis/Utility.cpp

Lines changed: 19 additions & 12 deletions
@@ -69,18 +69,25 @@ unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
   }
 
   unsigned threadOffset = 1;
-  if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(srcLayout)) {
-    auto parentLayout = sliceLayout.getParent();
-    auto threadsPerWarp = getThreadsPerWarp(parentLayout);
-    threadOffset = threadsPerWarp[sliceLayout.getDim()];
-  } else {
-    auto threadsPerWarp = getThreadsPerWarp(srcLayout);
-    auto order = getThreadOrder(srcLayout);
-    for (unsigned i = 0; i < order.size(); i++) {
-      if (order[i] == axis)
-        break;
-      threadOffset *= threadsPerWarp[order[i]];
-    }
+  SmallVector<int> dimsRemoved;
+  while (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(srcLayout)) {
+    dimsRemoved.push_back(sliceLayout.getDim());
+    srcLayout = sliceLayout.getParent();
+  }
+  // In case of slice layout we want to know the axis dimension relative to the
+  // most inner parent layout. `adjustedAxis` is the matching axis dim in the
+  // parent layout.
+  int adjustedAxis = axis;
+  for (auto dim : dimsRemoved) {
+    if (dim <= adjustedAxis)
+      adjustedAxis++;
+  }
+  auto threadsPerWarp = getThreadsPerWarp(srcLayout);
+  auto order = getThreadOrder(srcLayout);
+  for (unsigned i = 0; i < order.size(); i++) {
+    if (order[i] == adjustedAxis)
+      break;
+    threadOffset *= threadsPerWarp[order[i]];
   }
   return threadOffset;
 }
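
A minimal Python sketch of the axis remapping above, under the same conventions as the C++ (dimsRemoved collected from the outermost slice inwards, order listing the fastest-varying dimension first); the function name is made up for illustration, not a Triton API.

# Hypothetical helper mirroring getThreadOffsetOnReductionAxis (illustration only).
def thread_offset_on_reduction_axis(axis, dims_removed, threads_per_warp, order):
    # Each removed slice dim at or before the axis shifts the axis by one,
    # giving the matching dimension in the innermost parent layout.
    adjusted_axis = axis
    for dim in dims_removed:
        if dim <= adjusted_axis:
            adjusted_axis += 1
    # Accumulate the thread counts of the dimensions that vary faster than
    # the (adjusted) reduction axis.
    offset = 1
    for d in order:
        if d == adjusted_axis:
            break
        offset *= threads_per_warp[d]
    return offset

# Parent layout threadsPerWarp = [1, 4, 8] with order = [2, 1, 0], sliced along
# dim 2: axis 1 of the sliced tensor maps to dim 1 of the parent, and the 8
# threads along dim 2 sit below it in the lane numbering.
print(thread_offset_on_reduction_axis(axis=1, dims_removed=[2],
                                      threads_per_warp=[1, 4, 8],
                                      order=[2, 1, 0]))  # -> 8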

lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp

Lines changed: 51 additions & 5 deletions
@@ -218,6 +218,46 @@ struct ReduceOpConversion
     rewriter.replaceOp(op, results);
   }
 
+  // For slice layout some ids are duplicated on multiple lanes, so we need to
+  // handle the delinearization of laneId in a special way. We need to
+  // generalize this part of the logic to work on any kind of linear layout
+  // uniformly.
+  SmallVector<Value>
+  getMultiDimLaneId(ReduceOpHelper &helper, Value &laneId, Location &loc,
+                    ConversionPatternRewriter &rewriter) const {
+    auto srcLayout = helper.getSrcLayout();
+    auto srcShape = helper.getSrcShape();
+    auto order = triton::gpu::getThreadOrder(srcLayout);
+    SmallVector<Value> multiDimLaneId;
+
+    if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(srcLayout)) {
+      auto parentLayout = sliceLayout.getParent();
+      SmallVector<unsigned> dims = {sliceLayout.getDim()};
+      while (auto parentSliceLayout =
+                 mlir::dyn_cast<SliceEncodingAttr>(parentLayout)) {
+        dims.push_back(parentSliceLayout.getDim());
+        parentLayout = parentSliceLayout.getParent();
+      }
+
+      auto parentThreadsPerWarps = triton::gpu::getThreadsPerWarp(parentLayout);
+      auto parentOrder = triton::gpu::getThreadOrder(parentLayout);
+      multiDimLaneId = delinearize(rewriter, loc, laneId, parentThreadsPerWarps,
+                                   parentOrder);
+      for (unsigned dim : llvm::reverse(dims)) {
+        multiDimLaneId.erase(multiDimLaneId.begin() + dim);
+      }
+    } else {
+      SmallVector<unsigned> threadsPerWarps =
+          triton::gpu::getThreadsPerWarp(srcLayout);
+      threadsPerWarps[helper.getAxis()] =
+          triton::gpu::getThreadsPerWarpWithUniqueData(
+              srcLayout, srcShape)[helper.getAxis()];
+      multiDimLaneId =
+          delinearize(rewriter, loc, laneId, threadsPerWarps, order);
+    }
+    return multiDimLaneId;
+  }
+
   SmallVector<Value>
   getMultiDimWarpId(ReduceOpHelper &helper, Value &warpId, Location &loc,
                     ConversionPatternRewriter &rewriter) const {
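
A Python sketch of the delinearize-then-erase approach used by getMultiDimLaneId above. The delinearize here is a plain re-implementation for illustration, not Triton's helper, and slice_dims follows the same convention as dims in the C++ (collected from the outermost slice inwards, each relative to its own parent).

# Illustrative sketch only; names and conventions are assumptions, not Triton APIs.
def delinearize(linear, sizes, order):
    # Split a linear lane id into one coordinate per dimension, consuming the
    # fastest-varying dimension (order[0]) first.
    coords = [0] * len(sizes)
    for d in order:
        coords[d] = linear % sizes[d]
        linear //= sizes[d]
    return coords

def multi_dim_lane_id(lane_id, parent_threads_per_warp, parent_order, slice_dims):
    coords = delinearize(lane_id, parent_threads_per_warp, parent_order)
    # Drop the sliced-away dimensions, innermost slice first, so each outer
    # slice dim still indexes the already-reduced coordinate list correctly.
    for dim in reversed(slice_dims):
        del coords[dim]
    return coords

# Parent layout threadsPerWarp = [1, 4, 8], order = [2, 1, 0]; dims 0 and 2
# were sliced away, so only the coordinate along dim 1 survives.
print(multi_dim_lane_id(lane_id=13, parent_threads_per_warp=[1, 4, 8],
                        parent_order=[2, 1, 0], slice_dims=[0, 2]))  # -> [1]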
@@ -231,11 +271,20 @@ struct ReduceOpConversion
     // a way to properly delinearize warpId in the slice case
     if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(srcLayout)) {
       auto parentLayout = sliceLayout.getParent();
+      SmallVector<unsigned> dims = {sliceLayout.getDim()};
+      while (auto parentSliceLayout =
+                 mlir::dyn_cast<SliceEncodingAttr>(parentLayout)) {
+        dims.push_back(parentSliceLayout.getDim());
+        parentLayout = parentSliceLayout.getParent();
+      }
+
       auto parentWarpsPerCTA = triton::gpu::getWarpsPerCTA(parentLayout);
       auto parentOrder = triton::gpu::getWarpOrder(parentLayout);
       multiDimWarpId =
           delinearize(rewriter, loc, warpId, parentWarpsPerCTA, parentOrder);
-      multiDimWarpId.erase(multiDimWarpId.begin() + sliceLayout.getDim());
+      for (unsigned dim : llvm::reverse(dims)) {
+        multiDimWarpId.erase(multiDimWarpId.begin() + dim);
+      }
     } else {
       SmallVector<unsigned> warpsPerCTA =
           triton::gpu::getWarpsPerCTA(srcLayout);
@@ -263,11 +312,8 @@ struct ReduceOpConversion
     unsigned axis = op.getAxis();
     auto smemShape = helper.getScratchRepShape();
 
-    auto threadsPerWarp =
-        triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout, srcShape);
-    auto order = getThreadOrder(srcLayout);
     SmallVector<Value> multiDimLaneId =
-        delinearize(rewriter, loc, laneId, threadsPerWarp, order);
+        getMultiDimLaneId(helper, laneId, loc, rewriter);
     Value laneIdAxis = multiDimLaneId[axis];
     Value zero = i32_val(0);
     Value laneZero = icmp_eq(laneIdAxis, zero);

python/test/unit/language/test_core.py

Lines changed: 30 additions & 0 deletions
@@ -6013,3 +6013,33 @@ def sanitize_cumsum_kernel(Z, X, BLOCK: tl.constexpr):
     Z = torch.zeros_like(X)
     sanitize_cumsum_kernel[(1, )](Z, X, BLOCK=BLOCK)
     torch.testing.assert_close(Z, X.cumsum(0).to(torch.int32))
+
+
+# stress test slice layout usages in reductions.
+@pytest.mark.parametrize("in_shape, perm, red_dims", [
+    ((4, 32, 32, 4, 2), [2, 1, 0, 3, 4], [3, 1, 0]),
+    ((8, 2, 32, 4, 16), [4, 0, 1, 3, 2], [0, 2, 0]),
+])
+def test_chained_reductions(in_shape, perm, red_dims, device):
+
+    @triton.jit
+    def kernel(In, Out, #
+               dim_0: tl.constexpr, dim_1: tl.constexpr, dim_2: tl.constexpr, dim_3: tl.constexpr, dim_4: tl.constexpr,
+               perm_0: tl.constexpr, perm_1: tl.constexpr, perm_2: tl.constexpr, perm_3: tl.constexpr,
+               perm_4: tl.constexpr, red_dim_0: tl.constexpr, red_dim_1: tl.constexpr, red_dim_2: tl.constexpr):
+        idx = tl.arange(0, dim_0 * dim_1 * dim_2 * dim_3 * dim_4)
+        idx = idx.reshape(dim_0, dim_1, dim_2, dim_3, dim_4)
+        vals = tl.load(In + idx)
+        vals = tl.permute(vals, [perm_0, perm_1, perm_2, perm_3, perm_4])
+        r = tl.sum(tl.sum(tl.sum(vals, red_dim_0), red_dim_1), red_dim_2)
+        st_idx = tl.arange(0, r.shape[0] * r.shape[1]).reshape(r.shape)
+        tl.store(Out + st_idx, r)
+
+    input = torch.randint(0, 1000, in_shape, device=device, dtype=torch.int32)
+    temp = torch.permute(input, perm).contiguous()
+    ref = torch.sum(torch.sum(torch.sum(temp, dim=red_dims[0]), dim=red_dims[1]), dim=red_dims[2])
+    result = torch.empty_like(ref)
+    kernel[(1, )](input, result, input.shape[0], input.shape[1], input.shape[2], input.shape[3], input.shape[4],
+                  perm[0], perm[1], perm[2], perm[3], perm[4], red_dims[0], red_dims[1], red_dims[2])
+
+    assert torch.all(ref == result)
