Commit 5b7bc04

Authored by plognjen, oplavsic, and antiagainst
[AMD] Rewrite extract_slice op implementation (#7128)
This PR refactors the extract_slice operation to support two major improvements:

1) Relaxed layout constraints: the operation now allows more flexible source and destination layouts, aligning better with linear layouts.
2) Support for arbitrary tensor ranks: extract_slice is no longer limited to 2D tensors and can now handle tensors of any rank.

The "extract_slice" operation enables extracting a slice of a tensor in registers. It supports the following arguments:

* source: the base tensor on which to create a view tensor
* offsets: offsets into the base tensor at which to create the view

In distributed layouts, tensors are divided into CTA tiles. A CTA tile represents the smallest contiguous portion of a tensor that is distributed across all threads and warps within a workgroup. The extract_slice operation extracts a portion of the tensor that aligns with CTA tile boundaries.

This op is designed to work on logical tensors directly, avoiding the need for complex layout reinterpretation or reshaping. For example, the tt.split operation only supports splitting along the innermost dimension and requires that the resulting innermost dimension provide 2 elements per thread, distributed across registers. In contrast, the extract_slice op imposes no constraints on the extraction dimension or the size of dimensions.

---------

Co-authored-by: Ognjen Plavsic <[email protected]>
Co-authored-by: Lei Zhang <[email protected]>
1 parent 235496e commit 5b7bc04
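
Before the per-file diffs, a minimal sketch of the op in action, distilled from the lit tests in this PR. The function and value names here are ours; the layout, shapes, and offset come from the tests (for this blocked layout the CTA tile is 256x16, so the slice below is tile-aligned):

```mlir
#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
tt.func @extract_slice_sketch(%src: tensor<256x128xi32, #blocked>) {
  // Offsets must be multiples of the CTA tile shape, and source and result
  // must agree on the layout of a single CTA tile; the op is then a pure
  // register-level view with no cross-thread data movement.
  %slice = amdgpu.extract_slice %src [0, 0] : tensor<256x128xi32, #blocked> to tensor<256x16xi32, #blocked>
  tt.return
}
```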

File tree

14 files changed: +466 -315 lines changed

test/Conversion/amd/invalid_extractslice_to_llvm.mlir

Lines changed: 27 additions & 36 deletions
@@ -3,37 +3,17 @@
 // Invalid size
 #blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
 tt.func @invalid_size_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) {
-  // expected-error @+1 {{sizes [256, 2] must be a multiple of shapePerCTATile [256, 16]}}
+  // expected-error @+1 {{result shape must be multiple of shapePerCTATile}}
   %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x2xi32, #blocked1>
   tt.return
 }
 
 // -----
 
-// Invalid zero source dimension
-#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
-tt.func @invalid_size_input(%arg0: tensor<256x0xi32, #blocked1> {tt.divisibility = 16 : i32}) {
-  // expected-error @+1 {{source tensor dimension size zero at dimension 1}}
-  %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x0xi32, #blocked1> to tensor<256x16xi32, #blocked1>
-  tt.return
-}
-
-// -----
-
-// Invalid zero result dimension
-#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
-tt.func @invalid_size_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) {
-  // expected-error @+1 {{result tensor dimension size zero at dimension 1}}
-  %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x0xi32, #blocked1>
-  tt.return
-}
-
-// -----
-
 // Invalid offset, not multiple of shapePerTile
 #blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
 tt.func @invalid_offset_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) {
-  // expected-error @+1 {{offset [0, 5] must be a multiple of shapePerCTATile [256, 16]}}
+  // expected-error @+1 {{offset must be multiple of shapePerCTATile}}
   %1 = amdgpu.extract_slice %arg0 [0,5] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1>
   tt.return
 }
@@ -43,7 +23,7 @@ tt.func @invalid_offset_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibi
 // Invalid offset, out of bounds for dimension
 #blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
 tt.func @invalid_offset_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) {
-  // expected-error @+1 {{invalid offset 128 at dimension 1}}
+  // expected-error @+1 {{invalid offset at dimension 1}}
   %1 = amdgpu.extract_slice %arg0 [0,128] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1>
   tt.return
 }
@@ -54,11 +34,10 @@ tt.func @invalid_offset_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibi
 #blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
 #blocked2 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
 tt.func @invalid_result_layout(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) {
-  // expected-error @+1 {{result layout must match source layout}}
+  // expected-error @+1 {{CTA tile shapes must match between source and destination tensors.}}
   %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked2>
   tt.return
 }
-
 // -----
 
 // Invalid result element type
@@ -84,23 +63,13 @@ tt.func @invalid_result_rank(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibil
 // Invalid result shape
 #blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
 tt.func @invalid_result_rank(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) {
-  // expected-error @+1 {{result shape cannot be larger than input shape at dimension 1}}
+  // expected-error @+1 {{result shape cannot exceed source shape at dimension 1}}
   %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x256xi32, #blocked1>
   tt.return
 }
 
 // -----
 
-// Invalid rank
-#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
-tt.func @invalid_rank(%arg0: tensor<256x128x2xi32, #blocked1> {tt.divisibility = 16 : i32}) {
-  // expected-error @+1 {{currently only 2D tensors are supported}}
-  %1 = amdgpu.extract_slice %arg0 [0,0,0] : tensor<256x128x2xi32, #blocked1> to tensor<256x16x2xi32, #blocked1>
-  tt.return
-}
-
-// -----
-
 // Invalid non static offset
 #blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
 tt.func @invalid_non_static_offset(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}, %arg1: i32) {
@@ -109,3 +78,25 @@ tt.func @invalid_non_static_offset(%arg0: tensor<256x128xi32, #blocked1> {tt.div
   %2 = amdgpu.extract_slice %arg0 [%arg1, 0] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1>
   tt.return
 }
+
+// -----
+
+// Invalid layout 1
+#dst_layout = #ttg.linear<{register=[[0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [64, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 4]], warp=[[0, 32], [32, 0]], block=[]}>
+#src_layout = #ttg.linear<{register=[[0, 0], [0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [0, 128], [64, 0], [128, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 4]], warp=[[0, 32], [32, 0]], block=[]}>
+tt.func @invalid_register_base(%arg0: tensor<256x256xi32, #src_layout> {tt.divisibility = 16 : i32}) {
+  // expected-error @+1 {{Register basis must match on a CTA tile between source and destination}}
+  %2 = amdgpu.extract_slice %arg0 [0, 0] : tensor<256x256xi32, #src_layout> to tensor<128x128xi32, #dst_layout>
+  tt.return
+}
+
+// -----
+
+// Invalid layout 2
+#dst_layout = #ttg.linear<{register=[[0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [64, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 4]], warp=[[0, 32], [32, 0]], block=[]}>
+#src_layout = #ttg.linear<{register=[[0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [0, 128], [64, 0], [128, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 4], [0, 0]], warp=[[0, 32], [32, 0]], block=[]}>
+tt.func @invalid_lane_warp_basis(%arg0: tensor<256x256xi32, #src_layout> {tt.divisibility = 16 : i32}) {
+  // expected-error @+1 {{Lane and warp dim basis must match between source and destination layout}}
+  %2 = amdgpu.extract_slice %arg0 [0, 0] : tensor<256x256xi32, #src_layout> to tensor<128x128xi32, #dst_layout>
+  tt.return
+}
Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+// RUN: triton-opt %s -split-input-file -canonicalize | FileCheck %s
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  tt.func @canonicalize_after_concat(
+    %arg0: tensor<32x64xf32, #blocked>,
+    %arg1: tensor<32x64xf32, #blocked>,
+    %arg2: tensor<32x64xf32, #blocked>,
+    %arg3: tensor<32x64xf32, #blocked>,
+    %arg4: tensor<32x64xf32, #blocked>,
+    %arg5: tensor<32x64xf32, #blocked>,
+    %arg6: tensor<32x64xf32, #blocked>,
+    %arg7: tensor<32x64xf32, #blocked>) -> tensor<32x64xf32, #blocked> {
+    // CHECK-LABEL: tt.func @canonicalize_after_concat
+
+    %1 = amdgpu.concat %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7 :
+      tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked> -> tensor<128x128xf32, #blocked>
+    %2 = amdgpu.extract_slice %1 [32, 64] : tensor<128x128xf32, #blocked> to tensor<32x64xf32, #blocked>
+    // CHECK: tt.return %arg3 : tensor<32x64xf32, #blocked>
+    tt.return %2 : tensor<32x64xf32, #blocked>
+  }
+}
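
A note on the expected fold above (reasoning ours, not stated in the commit): the 128x128 result of amdgpu.concat is a 4x2 grid of 32x64 tiles with the last dimension varying fastest, so the slice at offset [32, 64] coincides exactly with the operand at grid cell (1, 1), i.e. %arg3. The canonicalizer can therefore forward that operand and drop both ops, which is what the CHECK line verifies.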
Lines changed: 49 additions & 6 deletions
@@ -1,14 +1,57 @@
-// RUN: triton-opt %s --convert-triton-amdgpu-to-llvm="arch=gfx942" | FileCheck %s
+// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm="arch=gfx942" | FileCheck %s
 
 #blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
 #blocked2 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
 module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
-  tt.func @basic_insert_slice(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) {
-    // CHECK: llvm.func @basic_insert_slice
-    // CHECK-COUNT-64: %{{[0-9]*}} = llvm.extractvalue %arg0[{{[0-9]*}}] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
-    // CHECK: %64 = llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
-    // CHECK-COUNT-8: %{{[0-9]*}} = llvm.insertvalue %{{[0-9]*}}, %{{[0-9]*}}[{{[0-9]*}}] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
+  tt.func @extract_2d_blocked_tensor(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) {
+    // CHECK-LABEL: llvm.func @extract_2d_blocked_tensor
+    // CHECK-COUNT-64: %{{.*}} = llvm.extractvalue %{{.*}} : !llvm.struct
+    // CHECK-COUNT-8: %{{.*}} = llvm.insertvalue %{{.*}} : !llvm.struct
     %72 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1>
     tt.return
   }
 }
+
+// -----
+
+#ll1 = #ttg.linear<{register = [[1, 0], [2, 0], [4, 0], [0, 16], [0, 32], [0, 64]], lane = [[0, 1], [0, 2], [0, 4], [0, 8], [8, 0], [16, 0]], warp = [[32, 0], [64, 0], [128, 0]], block = []}>
+#ll2 = #ttg.linear<{register = [[1, 0], [2, 0], [4, 0]], lane = [[0, 1], [0, 2], [0, 4], [0, 8], [8, 0], [16, 0]], warp = [[32, 0], [64, 0], [128, 0]], block = []}>
+
+module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  tt.func @extract_2d_linear_tensor(%arg0: tensor<256x128xi32, #ll1> {tt.divisibility = 16 : i32}) {
+    // CHECK-LABEL: llvm.func @extract_2d_linear_tensor
+    // CHECK-COUNT-64: %{{.*}} = llvm.extractvalue %arg0[{{[0-9]*}}] : !llvm.struct
+    // CHECK-COUNT-8: %{{.*}} = llvm.insertvalue %{{.*}} : !llvm.struct
+    %72 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #ll1> to tensor<256x16xi32, #ll2>
+    tt.return
+  }
+}
+
+// -----
+
+#ll1 = #ttg.linear<{register = [[0, 1, 0], [0, 2, 0], [0, 4, 0], [0, 0, 16], [0, 0, 32], [0, 0, 64], [1, 0, 0]], lane = [[0, 0, 1], [0, 0, 2], [0, 0, 4], [0, 0, 8], [0, 8, 0], [0, 16, 0]], warp = [[0, 32, 0], [0, 64, 0], [0, 128, 0]], block = []}>
+#ll2 = #ttg.linear<{register = [[0, 1, 0], [0, 2, 0], [0, 4, 0], [0, 0, 16], [0, 0, 32], [0, 0, 64]], lane = [[0, 0, 1], [0, 0, 2], [0, 0, 4], [0, 0, 8], [0, 8, 0], [0, 16, 0]], warp = [[0, 32, 0], [0, 64, 0], [0, 128, 0]], block = []}>
+
+module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  tt.func @extract_3d_linear_tensor(%arg0: tensor<2x256x128xi32, #ll1> {tt.divisibility = 16 : i32}) {
+    // CHECK-LABEL: llvm.func @extract_3d_linear_tensor
+    // CHECK-COUNT-128: %{{.*}} = llvm.extractvalue %arg0[{{.*}}] : !llvm.struct
+    // CHECK-COUNT-64: %{{[0-9]*}} = llvm.insertvalue %{{.*}} : !llvm.struct
+    %72 = amdgpu.extract_slice %arg0 [0,0,0] : tensor<2x256x128xi32, #ll1> to tensor<1x256x128xi32, #ll2>
+    tt.return
+  }
+}
+
+// -----
+
+#ll1 = #ttg.linear<{register=[[1], [256], [512]], lane=[[2], [4], [8], [16], [32], [64]], warp=[[128]], block=[]}>
+#ll2 = #ttg.linear<{register=[[1]], lane=[[2], [4], [8], [16], [32], [64]], warp=[[128]], block=[]}>
+module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  tt.func @extract_1d_linear_tensor(%arg0: tensor<1024xi32, #ll1> {tt.divisibility = 16 : i32}) {
+    // CHECK-LABEL: llvm.func @extract_1d_linear_tensor
+    // CHECK-COUNT-8: %{{.*}} = llvm.extractvalue %arg0[{{.*}}] : !llvm.struct
+    // CHECK-COUNT-2: %{{[0-9]*}} = llvm.insertvalue %{{.*}} : !llvm.struct
+    %72 = amdgpu.extract_slice %arg0 [0] : tensor<1024xi32, #ll1> to tensor<256xi32, #ll2>
+    tt.return
+  }
+}

third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td

Lines changed: 53 additions & 5 deletions
@@ -65,6 +65,53 @@ def ExtractSliceOp : TT_AMDGPU_Op<"extract_slice", [Pure]> {
     * source: the base tensor on which to create a view tensor
     * offsets: offsets into the base tensor at which to create the view
 
+    In distributed layouts, tensors are divided into CTA tiles.
+    A CTA tile represents the smallest contiguous portion of a tensor that is
+    distributed across all threads and warps within a workgroup.
+    The ExtractSlice operation extracts a portion of the tensor that is a
+    multiple of CTA tiles.
+
+    The source and destination must have matching linear layouts at the CTA
+    tile level. This ensures that the extract_slice is a no-op, meaning no data
+    rearrangement between threads is required to extract the destination tensor
+    with the given shape and layout.
+
+    +-------+-------+
+    |  W0   |  W1   |
+    |       |       |
+    |   +   |   +   |
+    |  W2   |  W3   |  <-- Single CTA tile (distributed across warps W0-W3)
+    |       |       |
+    |   +   |   +   |
+    |       |       |
+    +-------+-------+
+    |   Source Tensor       Extracted Slice
+    |         .             +--------------+
+    |         .             |  W0   |  W1  |
+    |         .             |       |      |
+    |                       |   +   |  +   |
+    |                       |  W2   |  W3  |
+    |                       |       |      |
+    |                       |   +   |  +   |
+    |                       |       |      |
+    |                       +-------+------+
+    |                       |  W0   |  W1  |
+    |                       |       |      |
+    |                       |   +   |  +   |
+    |                       |  W2      W3  |
+    |                       |       |      |
+    |                       |   +   |  +   |
+    |                       |       |      |
+    |                       +--------------+
+
+
+    This op is designed to work on logical tensors directly, avoiding the need
+    for complex layout reinterpretation or reshaping. For example, the tt.split
+    operation only supports splitting along the innermost dimension,
+    and requires that the resulting innermost dimension provide 2 elements per thread,
+    distributed across registers. In contrast, extract_slice op imposes no constraints
+    on the extraction dimension or the size of dimensions.
+
     Example 1:
 
     ```mlir
@@ -80,11 +127,11 @@ def ExtractSliceOp : TT_AMDGPU_Op<"extract_slice", [Pure]> {
     ```
 
     Example 1 shows how "extract_slice" operation may be used. In this example a
-    new slice of 128x32 is created. "extract_slice" works on tensors with layout
-    where the desired slice has the same layout as the source tensor.
-    "%0" cannot be sliced directly as the resulting slice cannot have the same
-    layout as "%0". Therefore it needs to be converted to a layout suitable
-    for slicing. "#blocked1" layout is appropriate for this as it keeps the
+    new slice of 128x32 is created. "extract_slice" works on tensors
+    where the desired slice has the same layout on a CTA tile as the source tensor.
+    "%0" cannot be sliced directly as the resulting slice does not satisfy this condition.
+    Therefore it needs to be converted to a layout suitable for slicing.
+    "#blocked1" layout is appropriate for this as it keeps the
     sizePerThread the same thus keeping coalescing properties the same.
     In order to utilize all threads in a warp, "threadsPerWarp" is set to
     [16,4] for this new layout. This layout conversion carried out before
@@ -117,6 +164,7 @@ def ExtractSliceOp : TT_AMDGPU_Op<"extract_slice", [Pure]> {
   }];
 
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 }
 
 def ConcatOp : TT_AMDGPU_Op<"concat", [Pure]> {
Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+#ifndef TRITON_THIRD_PARTY_AMD_INCLUDE_UTILS_UTILITY_H_
+#define TRITON_THIRD_PARTY_AMD_INCLUDE_UTILS_UTILITY_H_
+
+#include "llvm/ADT/ArrayRef.h"
+#include <cassert>
+#include <vector>
+namespace mlir::LLVM::AMD {
+
+template <typename T, typename U, typename BinaryOp>
+std::vector<unsigned> multiDimElementwise(const ArrayRef<T> &lhs,
+                                          const ArrayRef<U> &rhs, BinaryOp op) {
+  assert(lhs.size() == rhs.size() && "Input dimensions must match");
+  std::vector<unsigned> result;
+  result.reserve(lhs.size());
+  for (size_t i = 0, n = lhs.size(); i < n; ++i) {
+    unsigned a = static_cast<unsigned>(lhs[i]);
+    unsigned b = static_cast<unsigned>(rhs[i]);
+    result.push_back(op(a, b));
+  }
+  return result;
+}
+} // namespace mlir::LLVM::AMD
+#endif // TRITON_THIRD_PARTY_AMD_INCLUDE_UTILS_UTILITY_H_
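
The new helper is generic over the index element types. A usage sketch (ours, not part of the PR — the values and `std::divides` are illustrative, and the header path is inferred from the include guard) could count CTA tiles per dimension:

```cpp
#include <cstdint>
#include <functional>
#include <vector>

#include "llvm/ADT/ArrayRef.h"
// #include "third_party/amd/include/Utils/Utility.h"  // assumed path

int main() {
  // Tensor shape and CTA tile shape matching the 2D blocked tests above.
  std::vector<int64_t> shape{256, 128};
  std::vector<int64_t> ctaTile{256, 16};
  // Elementwise division: {256/256, 128/16} == {1, 8} CTA tiles per dim.
  std::vector<unsigned> tilesPerDim = mlir::LLVM::AMD::multiDimElementwise(
      llvm::ArrayRef<int64_t>(shape), llvm::ArrayRef<int64_t>(ctaTile),
      std::divides<unsigned>());
  return tilesPerDim[1] == 8 ? 0 : 1;  // sanity check
}
```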
