Skip to content

Commit 5cb1ebd

Browse files
authored
[XPU][OptEW] Allow multiple warps in non-sliced dimension (#2670)
Allow multiple warps in the non-sliced dimension as long as there are `n*sub_group_size` contiguous elements per warp in the non-sliced dimension.

Signed-off-by: victor-eds <[email protected]>
1 parent 541dfab commit 5cb1ebd

File tree

2 files changed

+125
-8
lines changed

2 files changed

+125
-8
lines changed

test/TritonIntelGPU/optimize-elementwise.mlir

Lines changed: 88 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -63,3 +63,91 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
6363
tt.return %0 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
6464
}
6565
}
66+
67+
// -----
68+
69+
// CHECK: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 1], order = [0, 1]}>
70+
// CHECK: #[[$ATTR_1:.+]] = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [2], order = [0]}>
71+
72+
#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 1], order = [0, 1]}>
73+
74+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
75+
// CHECK-LABEL: tt.func @test_blocked_multi_warp(
76+
// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>,
77+
// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>> {
78+
tt.func @test_blocked_multi_warp(%arg0: tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, %arg1: tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> {
79+
// CHECK: %[[VAL_2:.*]] = triton_gpu.convert_layout %[[VAL_0]] : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>> -> tensor<32xf32, #[[$ATTR_1]]>
80+
// CHECK: %[[VAL_3:.*]] = triton_gpu.convert_layout %[[VAL_1]] : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>> -> tensor<32xf32, #[[$ATTR_1]]>
81+
// CHECK: %[[VAL_4:.*]] = arith.addf %[[VAL_2]], %[[VAL_3]] : tensor<32xf32, #[[$ATTR_1]]>
82+
%0 = arith.addf %arg0, %arg1 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
83+
// CHECK: %[[VAL_5:.*]] = triton_gpu.convert_layout %[[VAL_4]] : tensor<32xf32, #[[$ATTR_1]]> -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>
84+
// CHECK: tt.return %[[VAL_5]] : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>
85+
tt.return %0 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
86+
}
87+
}
88+
89+
// -----
90+
91+
// CHECK: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [1, 16], warpsPerCTA = [4, 1], order = [0, 1]}>
92+
// CHECK: #[[$ATTR_1:.+]] = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [16], warpsPerCTA = [4], order = [0]}>
93+
94+
#blocked = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [1, 16], warpsPerCTA = [4, 1], order = [0, 1]}>
95+
96+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
97+
// CHECK-LABEL: tt.func @test_blocked_multi_warp_double_stride(
98+
// CHECK-SAME: %[[VAL_0:.*]]: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>,
99+
// CHECK-SAME: %[[VAL_1:.*]]: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>) -> tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>> {
100+
tt.func @test_blocked_multi_warp_double_stride(%arg0: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, %arg1: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>> {
101+
// CHECK: %[[VAL_2:.*]] = triton_gpu.convert_layout %[[VAL_0]] : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>> -> tensor<128xf16, #[[$ATTR_1]]>
102+
// CHECK: %[[VAL_3:.*]] = triton_gpu.convert_layout %[[VAL_1]] : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>> -> tensor<128xf16, #[[$ATTR_1]]>
103+
// CHECK: %[[VAL_4:.*]] = arith.addf %[[VAL_2]], %[[VAL_3]] : tensor<128xf16, #[[$ATTR_1]]>
104+
%0 = arith.addf %arg0, %arg1 : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
105+
// CHECK: %[[VAL_5:.*]] = triton_gpu.convert_layout %[[VAL_4]] : tensor<128xf16, #[[$ATTR_1]]> -> tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>
106+
// CHECK: tt.return %[[VAL_5]] : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>
107+
tt.return %0 : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
108+
}
109+
}
110+
111+
// -----
112+
113+
// CHECK: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [8], order = [0]}>
114+
// CHECK: #[[$ATTR_1:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>
115+
116+
#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>
117+
118+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
119+
// CHECK-LABEL: tt.func @test_mma_multi_warp_double_stride(
120+
// CHECK-SAME: %[[VAL_0:.*]]: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>,
121+
// CHECK-SAME: %[[VAL_1:.*]]: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>) -> tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> {
122+
tt.func @test_mma_multi_warp_double_stride(%arg0: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>>, %arg1: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>> {
123+
// CHECK: %[[VAL_2:.*]] = triton_gpu.convert_layout %[[VAL_0]] : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<128xf16, #[[$ATTR_0]]>
124+
// CHECK: %[[VAL_3:.*]] = triton_gpu.convert_layout %[[VAL_1]] : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<128xf16, #[[$ATTR_0]]>
125+
// CHECK: %[[VAL_4:.*]] = arith.addf %[[VAL_2]], %[[VAL_3]] : tensor<128xf16, #[[$ATTR_0]]>
126+
%0 = arith.addf %arg0, %arg1 : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>>
127+
// CHECK: %[[VAL_5:.*]] = triton_gpu.convert_layout %[[VAL_4]] : tensor<128xf16, #[[$ATTR_0]]> -> tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
128+
// CHECK: tt.return %[[VAL_5]] : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
129+
tt.return %0 : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>>
130+
}
131+
}
132+
133+
// -----
134+
135+
// CHECK: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [2], order = [0]}>
136+
// CHECK: #[[$ATTR_1:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>
137+
138+
#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>
139+
140+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
141+
// CHECK-LABEL: tt.func @test_mma_multi_warp_double_stride_repeat(
142+
// CHECK-SAME: %[[VAL_0:.*]]: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>,
143+
// CHECK-SAME: %[[VAL_1:.*]]: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>) -> tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> {
144+
tt.func @test_mma_multi_warp_double_stride_repeat(%arg0: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>>, %arg1: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>> {
145+
// CHECK: %[[VAL_2:.*]] = triton_gpu.convert_layout %[[VAL_0]] : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<128xf16, #[[$ATTR_0]]>
146+
// CHECK: %[[VAL_3:.*]] = triton_gpu.convert_layout %[[VAL_1]] : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<128xf16, #[[$ATTR_0]]>
147+
// CHECK: %[[VAL_4:.*]] = arith.addf %[[VAL_2]], %[[VAL_3]] : tensor<128xf16, #[[$ATTR_0]]>
148+
%0 = arith.addf %arg0, %arg1 : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>>
149+
// CHECK: %[[VAL_5:.*]] = triton_gpu.convert_layout %[[VAL_4]] : tensor<128xf16, #[[$ATTR_0]]> -> tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
150+
// CHECK: tt.return %[[VAL_5]] : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
151+
tt.return %0 : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>>
152+
}
153+
}

third_party/intel/lib/TritonIntelGPUTransforms/OptimizeElementwiseParallelism.cpp

Lines changed: 37 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -24,15 +24,36 @@ namespace mlir::triton::gpu::intel {
2424
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h.inc"
2525

2626
namespace {
27+
bool isMultiWarpValidLayoutForUnbroadcast(const LinearLayout &linearLayout,
28+
int32_t numWorkGroupPos,
29+
PatternRewriter &rewriter) {
30+
StringAttr kLane = rewriter.getStringAttr("lane");
31+
StringAttr kWarp = rewriter.getStringAttr("warp");
32+
int32_t subGroupSize = linearLayout.getInDimSize(kLane);
33+
ArrayRef<int32_t> numContiguousPerWarp = linearLayout.getBasis(kWarp, 0);
34+
// Check the warp dimension hasn't been sliced away and we have n *
35+
// sub_group_size contiguous elements per warp.
36+
if (numContiguousPerWarp == ArrayRef<int32_t>{0} ||
37+
numContiguousPerWarp[0] % subGroupSize != 0)
38+
return false;
39+
int32_t expectedValue = numContiguousPerWarp[0] * 2;
40+
for (int32_t pos = 1; pos < numWorkGroupPos; ++pos) {
41+
if (linearLayout.getBasis(kWarp, pos) != ArrayRef<int32_t>{expectedValue})
42+
return false;
43+
expectedValue *= 2;
44+
}
45+
return true;
46+
}
47+
2748
/// Return whether the input linear layout can be unbroadcasted.
2849
///
2950
/// A layout is valid for being "unbroadcasted" along its lanes if:
3051
/// - The 'lane' input dimension is zero: this means the lane dimension has been
3152
/// sliced.
3253
/// - The size of the input 'block' dimension is 1. This is true for XPU
3354
/// backend.
34-
/// - The size of the input 'warp' dimension is 1. This is a limitation to keep
35-
/// things simple for now.
55+
/// - The size of the input 'warp' dimension is 1 or there are n*sub_group_size
56+
/// contiguous elements per warp.
3657
///
3758
/// Broadcasted layouts are layouts with sliced lane, warp or block (not
3859
/// possible for XPU backend) dimensions, i.e., the same data is owned by
@@ -49,8 +70,11 @@ bool isValidLayoutForUnbroadcast(const LinearLayout &linearLayout,
4970
// Only single block for now.
5071
if (linearLayout.getInDimSize(kBlock) != 1)
5172
return false;
52-
// Only single warp for now.
53-
return linearLayout.getInDimSize(kWarp) == 1;
73+
// 'warp' dimension hasn't been sliced away and there are n*sub_group_size
74+
// contiguous elements in each warp (or there is a single warp).
75+
int32_t numWorkGroupPos = linearLayout.getInDimSizeLog2(kWarp);
76+
return numWorkGroupPos == 0 || isMultiWarpValidLayoutForUnbroadcast(
77+
linearLayout, numWorkGroupPos, rewriter);
5478
}
5579

5680
/// Get optimized unbroadcasted tensor type.
@@ -61,18 +85,23 @@ bool isValidLayoutForUnbroadcast(const LinearLayout &linearLayout,
6185
RankedTensorType getOptimizedType(RankedTensorType type,
6286
const LinearLayout &linearLayout,
6387
PatternRewriter &rewriter) {
88+
StringAttr kWarp = rewriter.getStringAttr("warp");
89+
6490
auto encoding = cast<DistributedEncodingTrait>(type.getEncoding());
6591
unsigned threadsPerWarp = product(encoding.getThreadsPerWarp());
66-
[[maybe_unused]] unsigned warpsPerCTA = product(encoding.getWarpsPerCTA());
67-
assert(warpsPerCTA == 1 && "Expecting single warp");
92+
unsigned warpsPerCTA = product(encoding.getWarpsPerCTA());
6893
[[maybe_unused]] unsigned ctaSplitNum = product(encoding.getCTASplitNum());
6994
assert(ctaSplitNum == 1 && "Expecting single CTA");
7095

7196
RankedTensorType::Builder builder(type);
97+
int32_t numWorkGroupPos = linearLayout.getInDimSizeLog2(kWarp);
98+
unsigned sizePerThread =
99+
numWorkGroupPos == 0
100+
? 1
101+
: linearLayout.getBasis(kWarp, 0)[0] / threadsPerWarp;
72102
CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(rewriter.getContext(), 1);
73103
auto newEncoding = rewriter.getAttr<BlockedEncodingAttr>(
74-
/*sizePerThread=*/1, threadsPerWarp, /*warpsPerCTA=*/1, /*order=*/0,
75-
ctaLayout);
104+
sizePerThread, threadsPerWarp, warpsPerCTA, /*order=*/0, ctaLayout);
76105
builder.setEncoding(newEncoding);
77106
return builder;
78107
}

0 commit comments

Comments
 (0)