
Commit 3c09bfe

Fix test_chained_reductions (#2821)
Fixes #2703

Signed-off-by: Tiotto, Ettore <[email protected]>
Co-authored-by: Whitney Tsang <[email protected]>
1 parent 8589959 commit 3c09bfe

File tree

8 files changed: +56 −17 lines changed

python/test/unit/language/test_core.py

Lines changed: 5 additions & 0 deletions

@@ -6109,6 +6109,11 @@ def sanitize_cumsum_kernel(Z, X, BLOCK: tl.constexpr):
     ((8, 2, 32, 4, 16), [4, 0, 1, 3, 2], [0, 2, 0]),
 ])
 def test_chained_reductions(in_shape, perm, red_dims, device):
+    if is_xpu() and in_shape == (4, 32, 32, 4, 2):
+        # check maximum shared memory
+        if triton.runtime.driver.active.utils.get_device_properties(
+                triton.runtime.driver.active.get_current_device())["max_shared_mem"] <= 163840:
+            pytest.xfail("XPU: Not enough shared memory")
 
     @triton.jit
     def kernel(In, Out, #
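
The guard added above can be read in isolation as follows. This is a minimal sketch, assuming the same driver API that appears in the diff (get_device_properties returning a dict with a "max_shared_mem" entry); the helper name xfail_if_insufficient_shared_memory is hypothetical and introduced only for illustration.

import pytest
import triton

def xfail_if_insufficient_shared_memory(required_bytes):
    # Hypothetical helper: query the active driver for the current device and
    # xfail the calling test if the reported shared-memory capacity is too small.
    device = triton.runtime.driver.active.get_current_device()
    props = triton.runtime.driver.active.utils.get_device_properties(device)
    if props["max_shared_mem"] <= required_bytes:
        pytest.xfail(f"Not enough shared memory (need > {required_bytes} bytes)")

# The patched test applies this check only on XPU and only for the largest
# input shape, with required_bytes = 163840 (160 KiB).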

scripts/skiplist/a770/language.txt

Lines changed: 0 additions & 2 deletions

@@ -1,5 +1,3 @@
-# https://github.com/intel/intel-xpu-backend-for-triton/issues/2703
-test/unit/language/test_core.py::test_chained_reductions[in_shape0-perm0-red_dims0]
 test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float16]
 test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float32]
 test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float32-float32]

scripts/skiplist/conda/language.txt

Lines changed: 0 additions & 2 deletions

@@ -113,8 +113,6 @@ test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[64-float8e5-128-
 test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[64-float8e4nv-128-256-128-128-256-256]
 test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[64-float8e4b15-128-256-128-128-256-256]
 test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[128-float8e5-128-256-128-128-256-256]
-# https://github.com/intel/intel-xpu-backend-for-triton/issues/2703
-test/unit/language/test_core.py::test_chained_reductions[in_shape0-perm0-red_dims0]
 test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float16]
 test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float32]
 test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float32-float32]
Lines changed: 0 additions & 2 deletions

@@ -1,2 +0,0 @@
-# https://github.com/intel/intel-xpu-backend-for-triton/issues/2703
-test/unit/language/test_core.py::test_chained_reductions[in_shape0-perm0-red_dims0]

scripts/skiplist/lts/language.txt

Lines changed: 0 additions & 2 deletions

@@ -113,8 +113,6 @@ test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[64-float8e5-128-
 test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[64-float8e4nv-128-256-128-128-256-256]
 test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[64-float8e4b15-128-256-128-128-256-256]
 test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[128-float8e5-128-256-128-128-256-256]
-# https://github.com/intel/intel-xpu-backend-for-triton/issues/2703
-test/unit/language/test_core.py::test_chained_reductions[in_shape0-perm0-red_dims0]
 test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float16]
 test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float32]
 test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float32-float32]

scripts/skiplist/mtl/language.txt

Lines changed: 0 additions & 2 deletions

@@ -1,5 +1,3 @@
-# https://github.com/intel/intel-xpu-backend-for-triton/issues/2703
-test/unit/language/test_core.py::test_chained_reductions[in_shape0-perm0-red_dims0]
 test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float16]
 test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float32]
 test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float32-float32]

scripts/skiplist/xe2/language.txt

Lines changed: 0 additions & 2 deletions

@@ -1,2 +0,0 @@
-# https://github.com/intel/intel-xpu-backend-for-triton/issues/2703
-test/unit/language/test_core.py::test_chained_reductions[in_shape0-perm0-red_dims0]

third_party/intel/lib/TritonIntelGPUToLLVM/ReduceOpToLLVM.cpp

Lines changed: 51 additions & 5 deletions

@@ -220,6 +220,46 @@ struct ReduceOpConversion
     rewriter.replaceOp(op, results);
   }
 
+  // For slice layout some ids are duplicated on multiple lanes, so we need to
+  // handle the delinearization of laneId in a special way. We need to
+  // generalize this part of the logic to work on any kind of linear layout
+  // uniformely.
+  SmallVector<Value>
+  getMultiDimLaneId(ReduceOpHelper &helper, Value &laneId, Location &loc,
+                    ConversionPatternRewriter &rewriter) const {
+    auto srcLayout = helper.getSrcLayout();
+    auto srcShape = helper.getSrcShape();
+    auto order = triton::gpu::getThreadOrder(srcLayout);
+    SmallVector<Value> multiDimLaneId;
+
+    if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(srcLayout)) {
+      auto parentLayout = sliceLayout.getParent();
+      SmallVector<unsigned> dims = {sliceLayout.getDim()};
+      while (auto parentSliceLayout =
+                 mlir::dyn_cast<SliceEncodingAttr>(parentLayout)) {
+        dims.push_back(parentSliceLayout.getDim());
+        parentLayout = parentSliceLayout.getParent();
+      }
+
+      auto parentThreadsPerWarps = triton::gpu::getThreadsPerWarp(parentLayout);
+      auto parentOrder = triton::gpu::getThreadOrder(parentLayout);
+      multiDimLaneId = delinearize(rewriter, loc, laneId, parentThreadsPerWarps,
+                                   parentOrder);
+      for (unsigned dim : llvm::reverse(dims)) {
+        multiDimLaneId.erase(multiDimLaneId.begin() + dim);
+      }
+    } else {
+      SmallVector<unsigned> threadsPerWarps =
+          triton::gpu::getThreadsPerWarp(srcLayout);
+      threadsPerWarps[helper.getAxis()] =
+          triton::gpu::getThreadsPerWarpWithUniqueData(
+              srcLayout, srcShape)[helper.getAxis()];
+      multiDimLaneId =
+          delinearize(rewriter, loc, laneId, threadsPerWarps, order);
+    }
+    return multiDimLaneId;
+  }
+
   SmallVector<Value>
   getMultiDimWarpId(ReduceOpHelper &helper, Value &warpId, Location &loc,
                     ConversionPatternRewriter &rewriter) const {

@@ -233,11 +273,20 @@ struct ReduceOpConversion
     // a way to properly delinearize warpId in the slice case
     if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(srcLayout)) {
       auto parentLayout = sliceLayout.getParent();
+      SmallVector<unsigned> dims = {sliceLayout.getDim()};
+      while (auto parentSliceLayout =
+                 mlir::dyn_cast<SliceEncodingAttr>(parentLayout)) {
+        dims.push_back(parentSliceLayout.getDim());
+        parentLayout = parentSliceLayout.getParent();
+      }
+
       auto parentWarpsPerCTA = triton::gpu::getWarpsPerCTA(parentLayout);
       auto parentOrder = triton::gpu::getWarpOrder(parentLayout);
       multiDimWarpId =
           delinearize(rewriter, loc, warpId, parentWarpsPerCTA, parentOrder);
-      multiDimWarpId.erase(multiDimWarpId.begin() + sliceLayout.getDim());
+      for (unsigned dim : llvm::reverse(dims)) {
+        multiDimWarpId.erase(multiDimWarpId.begin() + dim);
+      }
     } else {
       SmallVector<unsigned> warpsPerCTA =
           triton::gpu::getWarpsPerCTA(srcLayout);

@@ -265,11 +314,8 @@ struct ReduceOpConversion
     unsigned axis = op.getAxis();
     auto smemShape = helper.getScratchRepShape();
 
-    auto threadsPerWarp =
-        triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout, srcShape);
-    auto order = getThreadOrder(srcLayout);
     SmallVector<Value> multiDimLaneId =
-        delinearize(rewriter, loc, laneId, threadsPerWarp, order);
+        getMultiDimLaneId(helper, laneId, loc, rewriter);
     Value laneIdAxis = multiDimLaneId[axis];
     Value zero = i32_val(0);
     Value laneZero = icmp_eq(laneIdAxis, zero);

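To see what the new getMultiDimLaneId (and the generalized getMultiDimWarpId) computes for nested slice layouts, here is a small Python model of the arithmetic the pass emits. It is a sketch under the assumption that delinearize treats order[0] as the fastest-varying dimension; the function names mirror the C++ above but are not part of any Triton API.

def delinearize(linear, shape, order):
    # Model of the delinearize() helper: order[0] is assumed to be the
    # fastest-varying dimension; returns one coordinate per dimension.
    coords = [0] * len(shape)
    rest = linear
    for dim in order:
        coords[dim] = rest % shape[dim]
        rest //= shape[dim]
    return coords

def multi_dim_lane_id(lane_id, parent_threads_per_warp, parent_order, sliced_dims):
    # Delinearize against the innermost non-slice parent layout, then drop the
    # sliced dimensions in reverse collection order (outermost slice first),
    # mirroring the llvm::reverse(dims) loop in the patch.
    coords = delinearize(lane_id, parent_threads_per_warp, parent_order)
    for dim in reversed(sliced_dims):
        del coords[dim]
    return coords

# Example: a 16-lane warp laid out as [4, 2, 2] with dimension 0 fastest,
# viewed through a slice layout that removes dimension 1.
print(multi_dim_lane_id(13, [4, 2, 2], [0, 1, 2], [1]))  # -> [1, 1]

The same pattern, with getWarpsPerCTA and getWarpOrder in place of the threads-per-warp quantities, is what the generalized getMultiDimWarpId now does when slice layouts are nested.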