
Commit 0d9c0d3

[XPU][OptEW] Define -intel-triton-optimize-elementwise-parallelism pass (#2631)
Define a pass that improves elementwise parallelism by avoiding layout conversions that lead to data duplication between threads. See the pass documentation for more information.
---------
Signed-off-by: victor-eds <[email protected]>
1 parent ee755e8 commit 0d9c0d3

4 files changed, +274 −0 lines changed

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
// RUN: triton-opt %s --split-input-file -tritonintelgpu-optimize-elementwise-parallelism | FileCheck %s

// CHECK: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [1], order = [0]}>
// CHECK: #[[$ATTR_1:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>

#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
  // CHECK-LABEL: tt.func @test_dpas(
  // CHECK-SAME: %[[VAL_0:.*]]: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>,
  // CHECK-SAME: %[[VAL_1:.*]]: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>)
  tt.func @test_dpas(%arg0: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, %arg1: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> {
    // CHECK: %[[VAL_2:.*]] = triton_gpu.convert_layout %[[VAL_0]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<16xf32, #[[$ATTR_0]]>
    // CHECK: %[[VAL_3:.*]] = triton_gpu.convert_layout %[[VAL_1]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<16xf32, #[[$ATTR_0]]>
    // CHECK: %[[VAL_4:.*]] = arith.addf %[[VAL_2]], %[[VAL_3]] : tensor<16xf32, #[[$ATTR_0]]>
    %0 = arith.addf %arg0, %arg1 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
    // CHECK: %[[VAL_5:.*]] = triton_gpu.convert_layout %[[VAL_4]] : tensor<16xf32, #[[$ATTR_0]]> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
    // CHECK: tt.return %[[VAL_5]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
    tt.return %0 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
  }
}

// -----

// CHECK: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
// CHECK: #[[$ATTR_1:.+]] = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [1], order = [0]}>

#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
  // CHECK-LABEL: tt.func @test_blocked(
  // CHECK-SAME: %[[VAL_0:.*]]: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>,
  // CHECK-SAME: %[[VAL_1:.*]]: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>)
  tt.func @test_blocked(%arg0: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, %arg1: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> {
    // CHECK: %[[VAL_2:.*]] = triton_gpu.convert_layout %[[VAL_0]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>> -> tensor<16xf32, #[[$ATTR_1]]>
    // CHECK: %[[VAL_3:.*]] = triton_gpu.convert_layout %[[VAL_1]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>> -> tensor<16xf32, #[[$ATTR_1]]>
    // CHECK: %[[VAL_4:.*]] = arith.addf %[[VAL_2]], %[[VAL_3]] : tensor<16xf32, #[[$ATTR_1]]>
    %0 = arith.addf %arg0, %arg1 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    // CHECK: %[[VAL_5:.*]] = triton_gpu.convert_layout %[[VAL_4]] : tensor<16xf32, #[[$ATTR_1]]> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>
    // CHECK: tt.return %[[VAL_5]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>
    tt.return %0 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
  }
}

// -----

// CHECK: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
// CHECK: #[[$ATTR_1:.+]] = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [1], order = [0]}>

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
  // CHECK-LABEL: tt.func @test_blocked_repeat(
  // CHECK-SAME: %[[VAL_0:.*]]: tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>,
  // CHECK-SAME: %[[VAL_1:.*]]: tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>)
  tt.func @test_blocked_repeat(%arg0: tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, %arg1: tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> {
    // CHECK: %[[VAL_2:.*]] = triton_gpu.convert_layout %[[VAL_0]] : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>> -> tensor<64xf32, #[[$ATTR_1]]>
    // CHECK: %[[VAL_3:.*]] = triton_gpu.convert_layout %[[VAL_1]] : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>> -> tensor<64xf32, #[[$ATTR_1]]>
    // CHECK: %[[VAL_4:.*]] = arith.addf %[[VAL_2]], %[[VAL_3]] : tensor<64xf32, #[[$ATTR_1]]>
    %0 = arith.addf %arg0, %arg1 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    // CHECK: %[[VAL_5:.*]] = triton_gpu.convert_layout %[[VAL_4]] : tensor<64xf32, #[[$ATTR_1]]> -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>
    // CHECK: tt.return %[[VAL_5]] : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>
    tt.return %0 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
  }
}

third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td

Lines changed: 48 additions & 0 deletions
@@ -365,4 +365,52 @@ tt.func @test(%arg0: tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slic
                        "mlir::triton::gpu::TritonGPUDialect"];
}

def TritonIntelGPUOptimizeElementwiseParallelism
    : Pass<"tritonintelgpu-optimize-elementwise-parallelism", "mlir::ModuleOp"> {
  let summary =
      "Improve parallelism of elementwise operations to better utilize hardware resources.";

  let description = [{
    Detect elementwise operations with an encoding causing sub-par parallelism,
    i.e., with data duplication across threads, and convert the operands to a
    more optimal encoding if the cost of doing so is heuristically estimated to
    be sufficiently low. As of now, the cost should be 0: we only support
    "unbroadcasting" tensors, i.e., dropping duplicated values held in other
    threads by re-distributing them.

    As an example, this pass would modify the following code:
    ```mlir
    #blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>

    module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
      tt.func @test_blocked(%arg0: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, %arg1: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> {
        %0 = arith.addf %arg0, %arg1 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
        tt.return %0 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
      }
    }
    ```
    Obtaining:
    ```mlir
    #blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
    #blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [1], order = [0]}>

    module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
      tt.func @test_blocked(%arg0: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, %arg1: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> {
        %0 = triton_gpu.convert_layout %arg0 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16xf32, #blocked1>
        %1 = triton_gpu.convert_layout %arg1 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16xf32, #blocked1>
        %2 = arith.addf %0, %1 : tensor<16xf32, #blocked1>
        %3 = triton_gpu.convert_layout %2 : tensor<16xf32, #blocked1> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
        tt.return %3 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
      }
    }
    ```

    Note how the converted tensors are not sliced and thus each element in the
    tensor is held by a single thread.
  }];

  let dependentDialects = [];
}

#endif // TRITON_INTEL_GPU_PASSES
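
To make the parallelism claim concrete: under the original sliced layout (#triton_gpu.slice<{dim = 1, parent = #blocked}> with sizePerThread = [16, 1] and threadsPerWarp = [1, 16]), every one of the 16 lanes holds all 16 elements of the 1-D tensor, so each lane repeats the same 16 additions; after conversion to #blocked1, each lane holds exactly one element. The short standalone C++ sketch below is only an illustrative model of that accounting (it uses no Triton APIs); the element counts come from the layouts shown in the documentation above.

#include <cstdio>

// Illustrative model (not Triton code): how many elements of the 16-element
// tensor a single lane owns under the two layouts from the example above.
int main() {
  const int numElements = 16;  // tensor<16xf32>
  const int lanesPerWarp = 16; // "triton_gpu.threads-per-warp" = 16

  // Sliced parent layout (sizePerThread = [16, 1], threadsPerWarp = [1, 16]):
  // lanes only vary along the sliced dimension, so every lane owns the whole
  // tensor, i.e. the data is broadcast across lanes.
  const int elemsPerLaneBefore = numElements;

  // Unbroadcast layout (sizePerThread = [1], threadsPerWarp = [16]): the 16
  // elements are distributed one per lane.
  const int elemsPerLaneAfter = numElements / lanesPerWarp;

  std::printf("adds per lane before: %d, after: %d\n", elemsPerLaneBefore,
              elemsPerLaneAfter);
  return 0;
}

In other words, for this shape the rewritten arith.addf performs one addition per lane instead of sixteen redundant ones, which is exactly the parallelism gain the pass targets.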

third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ add_triton_library(TritonIntelGPUTransforms
  DistributeToWarps.cpp
  MatchTargetSize.cpp
  MaterializeBlockPointer.cpp
  OptimizeElementwiseParallelism.cpp
  OptimizeReductionLocality.cpp
  Pipeliner/MatmulLoopPipeline.cpp
  Pipeliner/SoftwarePipeliner.cpp
third_party/intel/lib/TritonIntelGPUTransforms/OptimizeElementwiseParallelism.cpp

Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
//===- OptimizeElementwiseParallelism.cpp ---------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// This file implements the `tritonintelgpu-optimize-elementwise-parallelism`
/// pass.
//===----------------------------------------------------------------------===//

#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"

#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

#define DEBUG_TYPE "tritonintelgpu-optimize-elementwise-parallelism"

namespace mlir::triton::gpu::intel {
#define GEN_PASS_DEF_TRITONINTELGPUOPTIMIZEELEMENTWISEPARALLELISM
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h.inc"

namespace {
/// Return whether the input linear layout can be unbroadcasted.
///
/// A layout is valid for being "unbroadcasted" along its lanes if:
/// - The 'lane' input dimension is zero: this means the lane dimension has
///   been sliced.
/// - The size of the input 'block' dimension is 1. This is true for the XPU
///   backend.
/// - The size of the input 'warp' dimension is 1. This is a limitation to
///   keep things simple for now.
///
/// Broadcasted layouts are layouts with sliced lane, warp or block (not
/// possible for the XPU backend) dimensions, i.e., the same data is owned by
/// different threads.
bool isValidLayoutForUnbroadcast(const LinearLayout &linearLayout,
                                 PatternRewriter &rewriter) {
  StringAttr kLane = rewriter.getStringAttr("lane");
  StringAttr kWarp = rewriter.getStringAttr("warp");
  StringAttr kBlock = rewriter.getStringAttr("block");
  StringAttr kDim0 = rewriter.getStringAttr("dim0");
  // 'lane' dimension must have been sliced away completely.
  if (!linearLayout.sublayoutIsZero(kLane, kDim0))
    return false;
  // Only single block for now.
  if (linearLayout.getInDimSize(kBlock) != 1)
    return false;
  // Only single warp for now.
  return linearLayout.getInDimSize(kWarp) == 1;
}

/// Get optimized unbroadcasted tensor type.
///
/// Get optimized ranked tensor type after unbroadcasting. As we only support
/// 1D tensors, this is as simple as getting an "unbroadcasted" blocked-encoded
/// 1D tensor type.
RankedTensorType getOptimizedType(RankedTensorType type,
                                  const LinearLayout &linearLayout,
                                  PatternRewriter &rewriter) {
  auto encoding = cast<DistributedEncodingTrait>(type.getEncoding());
  unsigned threadsPerWarp = product(encoding.getThreadsPerWarp());
  [[maybe_unused]] unsigned warpsPerCTA = product(encoding.getWarpsPerCTA());
  assert(warpsPerCTA == 1 && "Expecting single warp");
  [[maybe_unused]] unsigned ctaSplitNum = product(encoding.getCTASplitNum());
  assert(ctaSplitNum == 1 && "Expecting single CTA");

  RankedTensorType::Builder builder(type);
  CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(rewriter.getContext(), 1);
  auto newEncoding = rewriter.getAttr<BlockedEncodingAttr>(
      /*sizePerThread=*/1, threadsPerWarp, /*warpsPerCTA=*/1, /*order=*/0,
      ctaLayout);
  builder.setEncoding(newEncoding);
  return builder;
}

struct ElementwiseOptPattern final
    : OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern<OpTrait::Elementwise>::OpTraitRewritePattern;

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const final {
    // Rely on this for a simpler pass.
    if (!op->hasTrait<OpTrait::SameOperandsAndResultType>() ||
        op->getNumResults() != 1)
      return failure();

    // Skip complex operations.
    if (op->hasSuccessors() || op->getNumRegions() != 0)
      return failure();

    // Layout optimizations only apply to tensors.
    auto type = dyn_cast<RankedTensorType>(op->getResultTypes().front());
    if (!type)
      return failure();

    // Check if the layout is actually bad and can be optimized using our
    // approach. We only support 1D tensors for now as these are easier to
    // handle.
    Attribute layout = type.getEncoding();
    if (!layout || type.getRank() != 1)
      return failure();
    std::optional<LinearLayout> linearLayout =
        toLinearLayout(type.getShape(), layout);
    if (!linearLayout || !isValidLayoutForUnbroadcast(*linearLayout, rewriter))
      return failure();

    // Check that the operands are not used by other operations. This prevents
    // an increase in register pressure.
    if (!llvm::all_of(op->getOperands(),
                      [](Value val) { return val.hasOneUse(); }))
      return failure();

    // As we are dealing with 1D tensors, we can do a simple transform to
    // obtain a more optimized operation.
    Location loc = op->getLoc();
    RankedTensorType newType = getOptimizedType(type, *linearLayout, rewriter);
    SmallVector<Value> newOperands(op->getNumOperands());
    llvm::transform(op->getOperands(), std::begin(newOperands),
                    [&rewriter, loc, newType](Value operand) {
                      return rewriter.create<ConvertLayoutOp>(loc, newType,
                                                              operand);
                    });

    // Now we create the optimized operation:
    StringAttr opName = op->getName().getIdentifier();
    ArrayRef<NamedAttribute> attributes = op->getAttrs();
    Operation *newElementwiseOp =
        rewriter.create(loc, opName, newOperands, newType, attributes);
    assert(newElementwiseOp->getNumResults() == 1 &&
           "Expecting single result operation");

    // Convert back to the original (unoptimized) encoding for further use.
    Value newValue = newElementwiseOp->getResult(0);
    rewriter.replaceOpWithNewOp<ConvertLayoutOp>(op, type, newValue);

    return success();
  }
};

struct TritonIntelGPUOptimizeElementwiseParallelism final
    : impl::TritonIntelGPUOptimizeElementwiseParallelismBase<
          TritonIntelGPUOptimizeElementwiseParallelism> {
  using Base::Base;

  void runOnOperation() final {
    Operation *op = getOperation();
    MLIRContext *ctx = op->getContext();
    RewritePatternSet patterns(ctx);
    patterns.add<ElementwiseOptPattern>(ctx);
    if (failed(
            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace
} // namespace mlir::triton::gpu::intel
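
For context on how the new pass would be wired into a compilation flow, a minimal sketch follows. It assumes the TableGen-generated factory function follows the usual MLIR naming convention (createTritonIntelGPUOptimizeElementwiseParallelism) and is declared in the included Passes.h; the exact name and registration hook are assumptions, not confirmed by this commit. Since the pass is anchored on mlir::ModuleOp, it can be added directly to a top-level pass manager.

#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"

// Hypothetical pipeline setup; the factory name below is assumed from the
// usual TableGen convention and may differ in the actual repository.
static void addElementwiseParallelismPass(mlir::PassManager &pm) {
  pm.addPass(mlir::triton::gpu::intel::
                 createTritonIntelGPUOptimizeElementwiseParallelism());
}

Standalone, the same pass is exercised by the new lit test through triton-opt with the -tritonintelgpu-optimize-elementwise-parallelism flag, as shown in the test's RUN line above.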
