Skip to content

Commit a6a3f13

Browse files
committed
Address comments
1 parent 6fa6f7e commit a6a3f13

File tree

3 files changed

+43
-45
lines changed

3 files changed

+43
-45
lines changed

test/TritonIntelGPU/optimize-reduction.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
// RUN: triton-opt %s --split-input-file -tritonintelgpu-optimize-reduction-locality -canonicalize | FileCheck %s
1+
// RUN: triton-opt %s --split-input-file -tritonintelgpu-optimize-reduction-locality | FileCheck %s
22

33
// Test reduction in a single warp (16x16->16).
44

55
#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 1]}>
66

7-
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} {
7+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} {
88

99
// CHECK-DAG: #[[$ATTR_2:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}>
1010
// CHECK-DAG: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 1], order = [4, 0, 1, 2, 3]}>
@@ -50,7 +50,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
5050

5151
#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 1], repCluster = [2, 1]}>
5252

53-
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} {
53+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} {
5454

5555
// CHECK-DAG: #[[$ATTR_5:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 1], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}>
5656
// CHECK-DAG: #[[$ATTR_3:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [2, 1, 1, 1, 1], order = [4, 0, 1, 2, 3]}>
@@ -101,7 +101,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 :
101101
// CHECK-DAG: #[[$ATTR_7:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 2], order = [1, 0]}>
102102
// CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [1, 1, 2], order = [2, 0, 1]}>
103103

104-
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} {
104+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} {
105105

106106
// CHECK-LABEL: tt.func @test_two_warps_red(
107107
// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x32xf32, #[[$ATTR_8]]>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_8]]}>> {
@@ -147,7 +147,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 :
147147
// CHECK-DAG: #[[$ATTR_11:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}>
148148
// CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [2, 1, 2], order = [2, 0, 1]}>
149149

150-
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} {
150+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} {
151151

152152
// CHECK-LABEL: tt.func @test_two_warps(
153153
// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x32xf32, #[[$ATTR_11]]>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_11]]}>> {
@@ -225,7 +225,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
225225

226226
#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [2, 2]}>
227227

228-
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} {
228+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} {
229229
// CHECK: tt.func @test(
230230
// CHECK-SAME: %[[VAL_0:.*]]: tensor<64x64xf32, #[[$ATTR_14]]>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_14]]}>> {
231231
// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true, efficient_layout} : tensor<64x64xf32, #[[$ATTR_14]]> -> tensor<64x16x2x2x1xf32, #[[$ATTR_12]]>

third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td

Lines changed: 28 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -294,16 +294,13 @@ def TritonIntelGPUOptimizeReductionLocality
294294
`triton_gpu.convert_layout` operations, e.g.:
295295
```mlir
296296
#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}>
297-
298-
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} {
299-
tt.func @test.work(%arg0: tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> {
300-
%0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({
301-
^bb0(%arg1: f32, %arg2: f32):
302-
%1 = arith.addf %arg1, %arg2 : f32
303-
tt.reduce.return %1 : f32
304-
}) : (tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>>
305-
tt.return %0 : tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>>
306-
}
297+
tt.func @test(%arg0: tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> {
298+
%0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({
299+
^bb0(%arg1: f32, %arg2: f32):
300+
%1 = arith.addf %arg1, %arg2 : f32
301+
tt.reduce.return %1 : f32
302+
}) : (tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>>
303+
tt.return %0 : tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>>
307304
}
308305
```
309306
Is converted to:
@@ -312,29 +309,27 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
312309
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [2, 1, 2], order = [2, 0, 1]}>
313310
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 2], order = [1, 0]}>
314311
#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1], A = [8, 8], B = [8, 16], C = [8, 16]}>
315-
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} {
316-
tt.func @test_two_warps_twice(%arg0: tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> {
317-
%0 = tt.reshape %arg0 {allow_reorder = true} : tensor<32x32xf32, #mma> -> tensor<32x16x1x2x1xf32, #blocked>
318-
%1 = "tt.reduce"(%0) <{axis = 4 : i32}> ({
319-
^bb0(%arg1: f32, %arg2: f32):
320-
%7 = arith.addf %arg1, %arg2 : f32
321-
tt.reduce.return %7 : f32
322-
}) : (tensor<32x16x1x2x1xf32, #blocked>) -> tensor<32x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #blocked}>>
323-
%2 = "tt.reduce"(%1) <{axis = 2 : i32}> ({
324-
^bb0(%arg1: f32, %arg2: f32):
325-
%7 = arith.addf %arg1, %arg2 : f32
326-
tt.reduce.return %7 : f32
327-
}) : (tensor<32x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #blocked}>>) -> tensor<32x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #blocked}>}>>
328-
%3 = triton_gpu.convert_layout %2 : tensor<32x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #blocked}>}>> -> tensor<32x16x2xf32, #blocked1>
329-
%4 = tt.reshape %3 {allow_reorder = true} : tensor<32x16x2xf32, #blocked1> -> tensor<32x32xf32, #blocked2>
330-
%5 = "tt.reduce"(%4) <{axis = 1 : i32}> ({
331-
^bb0(%arg1: f32, %arg2: f32):
332-
%7 = arith.addf %arg1, %arg2 : f32
333-
tt.reduce.return %7 : f32
334-
}) : (tensor<32x32xf32, #blocked2>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
335-
%6 = triton_gpu.convert_layout %5 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
336-
tt.return %6 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
337-
}
312+
tt.func @test(%arg0: tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> {
313+
%0 = tt.reshape %arg0 {allow_reorder = true} : tensor<32x32xf32, #mma> -> tensor<32x16x1x2x1xf32, #blocked>
314+
%1 = "tt.reduce"(%0) <{axis = 4 : i32}> ({
315+
^bb0(%arg1: f32, %arg2: f32):
316+
%7 = arith.addf %arg1, %arg2 : f32
317+
tt.reduce.return %7 : f32
318+
}) : (tensor<32x16x1x2x1xf32, #blocked>) -> tensor<32x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #blocked}>>
319+
%2 = "tt.reduce"(%1) <{axis = 2 : i32}> ({
320+
^bb0(%arg1: f32, %arg2: f32):
321+
%7 = arith.addf %arg1, %arg2 : f32
322+
tt.reduce.return %7 : f32
323+
}) : (tensor<32x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #blocked}>>) -> tensor<32x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #blocked}>}>>
324+
%3 = triton_gpu.convert_layout %2 : tensor<32x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #blocked}>}>> -> tensor<32x16x2xf32, #blocked1>
325+
%4 = tt.reshape %3 {allow_reorder = true} : tensor<32x16x2xf32, #blocked1> -> tensor<32x32xf32, #blocked2>
326+
%5 = "tt.reduce"(%4) <{axis = 1 : i32}> ({
327+
^bb0(%arg1: f32, %arg2: f32):
328+
%7 = arith.addf %arg1, %arg2 : f32
329+
tt.reduce.return %7 : f32
330+
}) : (tensor<32x32xf32, #blocked2>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
331+
%6 = triton_gpu.convert_layout %5 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
332+
tt.return %6 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
338333
}
339334
```
340335
The `tt.reshape` operation is a NOP so that the following `tt.reduce`

third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ namespace mlir::triton::gpu::intel {
2323

2424
namespace {
2525
static CTALayoutAttr getIdentityCTALayoutAttr(PatternRewriter &rewriter,
26-
std::size_t rank) {
26+
size_t rank) {
2727
SmallVector<unsigned> ctasPerCGA(rank, 1);
2828
SmallVector<unsigned> ctaSplitNum(rank, 1);
2929
SmallVector<unsigned> ctaOrder(rank);
@@ -121,7 +121,7 @@ static Value createReshapeForReduction(PatternRewriter &rewriter, Location loc,
121121
/// And reducing on dimension 1 and converting the layout to the original one
122122
/// leads to the same output as the original operation.
123123
// clang-format on
124-
struct DPasOperandPattern final : OpRewritePattern<ReduceOp> {
124+
struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
125125
using OpRewritePattern<ReduceOp>::OpRewritePattern;
126126

127127
static constexpr int preferredNonReductionAxis = 0;
@@ -197,6 +197,7 @@ struct DPasOperandPattern final : OpRewritePattern<ReduceOp> {
197197
return success();
198198
}
199199

200+
private:
200201
Value reshapeForElementWiseReduction(ReduceOp op,
201202
PatternRewriter &rewriter) const {
202203
assert(op.getOperands().size() == 1 && "Expecting a single operand");
@@ -206,7 +207,7 @@ struct DPasOperandPattern final : OpRewritePattern<ReduceOp> {
206207
ArrayRef<int64_t> oldShape = oldType.getShape();
207208
auto oldEncoding = cast<DpasEncodingAttr>(oldType.getEncoding());
208209

209-
constexpr std::size_t rank = 5;
210+
constexpr size_t rank = 5;
210211
std::array<int64_t, rank> shape{
211212
// Y axis
212213
oldShape[0],
@@ -245,6 +246,8 @@ struct DPasOperandPattern final : OpRewritePattern<ReduceOp> {
245246

246247
Value performReduction(ReduceOp op, PatternRewriter &rewriter, Value val,
247248
int axis) const {
249+
assert(axis >= 0 && "Expecting positive axis");
250+
248251
auto newOp = rewriter.create<ReduceOp>(op.getLoc(), val, /*axis=*/axis);
249252
auto &newCombineOp = newOp.getCombineOp();
250253
rewriter.cloneRegionBefore(op.getCombineOp(), newCombineOp,
@@ -275,7 +278,7 @@ struct DPasOperandPattern final : OpRewritePattern<ReduceOp> {
275278
cast<RankedTensorType>(op.getOperands().front().getType())
276279
.getEncoding());
277280

278-
constexpr std::size_t rank = 3;
281+
constexpr size_t rank = 3;
279282
ArrayRef<int64_t> shape = oldType.getShape();
280283
std::array<unsigned, rank> sizePerThread{1, dpasEncoding.getExecutionSize(),
281284
1};
@@ -301,7 +304,7 @@ struct DPasOperandPattern final : OpRewritePattern<ReduceOp> {
301304
ArrayRef<int64_t> oldShape = oldType.getShape();
302305
auto oldEncoding = cast<BlockedEncodingAttr>(oldType.getEncoding());
303306

304-
constexpr std::size_t rank = 2;
307+
constexpr size_t rank = 2;
305308
std::array<int64_t, rank> shape{oldShape[0], oldShape[1] * oldShape[2]};
306309
std::array<unsigned, rank> sizePerThread{1,
307310
oldEncoding.getSizePerThread()[1]};
@@ -346,7 +349,7 @@ struct TritonIntelGPUOptimizeReductionLocality final
346349
Operation *op = getOperation();
347350
MLIRContext *ctx = op->getContext();
348351
RewritePatternSet patterns(ctx);
349-
patterns.add<DPasOperandPattern>(ctx);
352+
patterns.add<DpasOperandPattern>(ctx);
350353
if (failed(
351354
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
352355
signalPassFailure();

0 commit comments

Comments (0)