
Commit 76ed95b

[AMD] Add a Concat op to AMDGPU dialect (#6590)
The "concat" operation combines a list of source n-dimensional tensors into a single larger destination tensor. All source tensors must have the same shape, element type, and encoding. The concatenation dimension is inferred from the source and destination shapes provided by the user. For example, two tensors of shape 64x128 can produce a destination shape of 128x128, indicating concatenation along dimension 0; or 64x256, indicating concatenation along dimension 1. Generally, source tensors passed as op arguments can be arranged into the resulting shape in multiple ways. For example, given four tensors of shape 64x64: concat s0<64x64>, s1<64x64>, s2<64x64>, s3<64x64> -> <128x128> They can be laid out in different configurations within the result tensor: 1) s0 s1 s2 s3 2) s0 s2 s1 s3 From a logical tensor perspective, the source tensors are treated as elements of a tensor of tensors. In other words, the 1-D array of input tensors is conceptually reshaped into an n-D grid. The semantics of this op assume a row-major order (or its n-D generalization), meaning the fastest-varying dimension is filled first, and the slowest-varying dimension is filled last. In the example above, this corresponds to layout 1). The source and destination tensors must have identical linear layouts at the CTA tile level. That is, all base vectors for input dimensions must match, except for the register input dimension. The register basis must align on the subset that defines the logical tensor shape of a single CTA tile. This ensures that the concatenation is a no-op, meaning no data rearrangement among threads is required to assemble the destination tensor with the given shape and layout. However, the order of CTA tiles within the layout does not need to match between source and destination layouts. It is the responsibility of the op's lowering logic to handle this correctly. This op is designed to work on logical tensors directly, avoiding the need for complex layout reinterpretation or reshaping. For example, the `tt.join` operation only supports concatenation along the innermost dimension, and requires that the resulting innermost dimension provide 2 elements per thread, distributed across registers. In contrast, this `concat` op imposes no constraints on the concatenation dimension or the size of dimensions. --------- Co-authored-by: Ognjen Plavsic <[email protected]> Co-authored-by: Lei Zhang <[email protected]>

11 files changed (+844, -113 lines)


.github/workflows/integration-tests-amd.yml

Lines changed: 1 addition & 1 deletion
@@ -109,7 +109,7 @@ jobs:
           echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
         fi
         pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py
-        pytest --capture=tee-sys -rfs third_party/amd/python/test/test_extract_slice.py
+        pytest --capture=tee-sys -rfs third_party/amd/python/test/test_extract_slice_concat_op.py
         TRITON_ALWAYS_COMPILE=1 pytest --capture=tee-sys -rfs third_party/amd/python/test/test_scalarize_packed_fops.py
         cd python/test/unit
         pytest --capture=tee-sys -rfs -n 12 language runtime \
Lines changed: 174 additions & 0 deletions
@@ -0,0 +1,174 @@
// RUN: triton-opt -split-input-file %s --convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics


// Invalid ranks
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
  tt.func @invalid_concat(
      %arg0: tensor<32x64xf32, #blocked>,
      %arg1: tensor<32x64xf32, #blocked>,
      %arg2: tensor<32x64xf32, #blocked>,
      %arg3: tensor<32x64xf32, #blocked>,
      %arg4: tensor<32x64xf32, #blocked>,
      %arg5: tensor<32x64xf32, #blocked>,
      %arg6: tensor<32x64xf32, #blocked>,
      %arg7: tensor<32x64xf32, #blocked>) {

    // expected-error @+1 {{Source and destination tensors must have the same rank.}}
    %1 = amdgpu.concat %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7 :
      tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked> -> tensor<256xf32, #blocked>
    tt.return
  }
}

// -----

// Invalid shapes 1
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
  tt.func @invalid_concat(
      %arg0: tensor<32x64xf32, #blocked>,
      %arg1: tensor<32x64xf32, #blocked>,
      %arg2: tensor<32x64xf32, #blocked>,
      %arg3: tensor<32x64xf32, #blocked>,
      %arg4: tensor<32x64xf32, #blocked>,
      %arg5: tensor<32x64xf32, #blocked>,
      %arg6: tensor<32x64xf32, #blocked>,
      %arg7: tensor<32x64xf32, #blocked>) {

    // expected-error @+1 {{Source and destination tensor shapes don't match.}}
    %1 = amdgpu.concat %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7 :
      tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked> -> tensor<257x128xf32, #blocked>
    tt.return
  }
}

// -----

// Invalid shapes 2
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
  tt.func @invalid_concat(
      %arg0: tensor<32x64xf32, #blocked>,
      %arg1: tensor<32x64xf32, #blocked>,
      %arg2: tensor<32x64xf32, #blocked>,
      %arg3: tensor<32x64xf32, #blocked>,
      %arg4: tensor<32x64xf32, #blocked>,
      %arg5: tensor<32x64xf32, #blocked>,
      %arg6: tensor<32x64xf32, #blocked>,
      %arg7: tensor<32x64xf32, #blocked>) {

    // expected-error @+1 {{Number of source tiles (8) doesn't match required count (16).}}
    %1 = amdgpu.concat %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7 :
      tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked> -> tensor<256x128xf32, #blocked>
    tt.return
  }
}

// -----

// Invalid shapes 3
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
  tt.func @invalid_concat(
      %arg0: tensor<32x64xf32, #blocked>,
      %arg1: tensor<32x64xf32, #blocked>,
      %arg2: tensor<32x64xf32, #blocked>,
      %arg3: tensor<32x64xf32, #blocked>,
      %arg4: tensor<32x64xf32, #blocked>,
      %arg5: tensor<32x64xf32, #blocked>,
      %arg6: tensor<32x64xf32, #blocked>,
      %arg7: tensor<32x64xf32, #blocked>) {

    // expected-error @+1 {{CTA tile shapes must match between source and destination tensors.}}
    %1 = amdgpu.concat %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7 :
      tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked> -> tensor<128x128xf32, #blocked1>
    tt.return
  }
}

// -----

// Different types
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
  tt.func @invalid_concat(
      %arg0: tensor<32x64xf32, #blocked1>,
      %arg1: tensor<32x64xf32, #blocked>,
      %arg2: tensor<32x64xf32, #blocked>,
      %arg3: tensor<32x64xf32, #blocked>,
      %arg4: tensor<32x64xf32, #blocked>,
      %arg5: tensor<32x64xf32, #blocked>,
      %arg6: tensor<32x64xf32, #blocked>,
      %arg7: tensor<32x64xf32, #blocked>) {

    // expected-error @+1 {{All sources must have identical tensor types.}}
    %1 = amdgpu.concat %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7 :
      tensor<32x64xf32, #blocked1>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked> -> tensor<128x128xf32, #blocked>
    tt.return
  }
}

// -----

// Invalid element types
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
  tt.func @invalid_concat(
      %arg0: tensor<32x64xf32, #blocked>,
      %arg1: tensor<32x64xf32, #blocked>,
      %arg2: tensor<32x64xf32, #blocked>,
      %arg3: tensor<32x64xf32, #blocked>,
      %arg4: tensor<32x64xf32, #blocked>,
      %arg5: tensor<32x64xf32, #blocked>,
      %arg6: tensor<32x64xf32, #blocked>,
      %arg7: tensor<32x64xf32, #blocked>) {

    // expected-error @+1 {{Element types of sources and destination must match.}}
    %1 = amdgpu.concat %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7 :
      tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked> -> tensor<256x64xf16, #blocked>
    tt.return
  }
}

// -----

// Different layouts 1
#src_layout = #ttg.linear<{register=[[0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [64, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 4]], warp=[[0, 32], [32, 0]], block=[]}>
#dst_layout = #ttg.linear<{register=[[0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [0, 128], [64, 0], [128, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 4], [0, 0]], warp=[[0, 32], [32, 0]], block=[]}>
module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
  tt.func @invalid_concat(
      %arg0: tensor<128x128xf32, #src_layout>,
      %arg1: tensor<128x128xf32, #src_layout>,
      %arg2: tensor<128x128xf32, #src_layout>,
      %arg3: tensor<128x128xf32, #src_layout>) {

    // expected-error @+1 {{Lane and warp dim basis must match between source and destination layout.}}
    %1 = amdgpu.concat %arg0, %arg1, %arg2, %arg3 :
      tensor<128x128xf32, #src_layout>, tensor<128x128xf32, #src_layout>, tensor<128x128xf32, #src_layout>, tensor<128x128xf32, #src_layout> -> tensor<256x256xf32, #dst_layout>
    tt.return
  }
}

// -----

// Different layouts 2
#src_layout = #ttg.linear<{register=[[0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [64, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 4]], warp=[[0, 32], [32, 0]], block=[]}>
#dst_layout = #ttg.linear<{register=[[0, 0], [0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [0, 128], [64, 0], [128, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 4]], warp=[[0, 32], [32, 0]], block=[]}>
module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
  tt.func @invalid_concat(
      %arg0: tensor<128x128xf32, #src_layout>,
      %arg1: tensor<128x128xf32, #src_layout>,
      %arg2: tensor<128x128xf32, #src_layout>,
      %arg3: tensor<128x128xf32, #src_layout>) {

    // expected-error @+1 {{Register basis must match on a CTA tile between source and destination.}}
    %1 = amdgpu.concat %arg0, %arg1, %arg2, %arg3 :
      tensor<128x128xf32, #src_layout>, tensor<128x128xf32, #src_layout>, tensor<128x128xf32, #src_layout>, tensor<128x128xf32, #src_layout> -> tensor<256x256xf32, #dst_layout>
    tt.return
  }
}

test/TritonGPU/amd/amd-concat-op.mlir

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm='arch=gfx942' | FileCheck %s

// -----

#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
  tt.func @concat_blocked(
      %arg0: tensor<32x64xf32, #blocked1>,
      %arg1: tensor<32x64xf32, #blocked1>,
      %arg2: tensor<32x64xf32, #blocked1>,
      %arg3: tensor<32x64xf32, #blocked1>,
      %arg4: tensor<32x64xf32, #blocked1>,
      %arg5: tensor<32x64xf32, #blocked1>,
      %arg6: tensor<32x64xf32, #blocked1>,
      %arg7: tensor<32x64xf32, #blocked1>) {
    // CHECK: llvm.func @concat_blocked

    // CHECK-COUNT-8: %{{.*}} = llvm.extractvalue %arg0[{{.*}}] : !llvm.struct
    // CHECK-COUNT-8: %{{.*}} = llvm.extractvalue %arg1[{{.*}}] : !llvm.struct
    // CHECK-COUNT-8: %{{.*}} = llvm.extractvalue %arg2[{{.*}}] : !llvm.struct
    // CHECK-COUNT-8: %{{.*}} = llvm.extractvalue %arg3[{{.*}}] : !llvm.struct
    // CHECK-COUNT-8: %{{.*}} = llvm.extractvalue %arg4[{{.*}}] : !llvm.struct
    // CHECK-COUNT-8: %{{.*}} = llvm.extractvalue %arg5[{{.*}}] : !llvm.struct
    // CHECK-COUNT-8: %{{.*}} = llvm.extractvalue %arg6[{{.*}}] : !llvm.struct
    // CHECK-COUNT-8: %{{.*}} = llvm.extractvalue %arg7[{{.*}}] : !llvm.struct

    // CHECK-COUNT-64: %{{[0-9]*}} = llvm.insertvalue %{{.*}} : !llvm.struct

    %1 = amdgpu.concat %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7 :
      tensor<32x64xf32, #blocked1>, tensor<32x64xf32, #blocked1>, tensor<32x64xf32, #blocked1>, tensor<32x64xf32, #blocked1>, tensor<32x64xf32, #blocked1>, tensor<32x64xf32, #blocked1>, tensor<32x64xf32, #blocked1>, tensor<32x64xf32, #blocked1> -> tensor<128x128xf32, #blocked1>
    tt.return
  }
}

// -----

#src_layout = #ttg.linear<{register=[[0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [64, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 4]], warp=[[0, 32], [32, 0]], block=[]}>
#dst_layout = #ttg.linear<{register=[[0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [0, 128], [64, 0], [128, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 4]], warp=[[0, 32], [32, 0]], block=[]}>
module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
  tt.func @concat_ll_2d_1(
      %arg0: tensor<128x128xf32, #src_layout>,
      %arg1: tensor<128x128xf32, #src_layout>,
      %arg2: tensor<128x128xf32, #src_layout>,
      %arg3: tensor<128x128xf32, #src_layout>) {
    // CHECK: llvm.func @concat_ll_2d_1

    // CHECK-COUNT-64: %{{.*}} = llvm.extractvalue %arg0[{{.*}}] : !llvm.struct
    // CHECK-COUNT-64: %{{.*}} = llvm.extractvalue %arg1[{{.*}}] : !llvm.struct
    // CHECK-COUNT-64: %{{.*}} = llvm.extractvalue %arg2[{{.*}}] : !llvm.struct
    // CHECK-COUNT-64: %{{.*}} = llvm.extractvalue %arg3[{{.*}}] : !llvm.struct
    // CHECK-COUNT-256: %{{.*}} = llvm.insertvalue %{{.*}} : !llvm.struct

    %1 = amdgpu.concat %arg0, %arg1, %arg2, %arg3 :
      tensor<128x128xf32, #src_layout>, tensor<128x128xf32, #src_layout>, tensor<128x128xf32, #src_layout>, tensor<128x128xf32, #src_layout> -> tensor<256x256xf32, #dst_layout>
    tt.return
  }
}

// -----

#src_layout = #ttg.linear<{register=[[1, 0], [2, 0], [4, 0]], lane=[[0, 1], [0, 2], [0, 4], [0, 8], [8, 0], [16, 0]], warp=[[0, 16]], block=[]}>
#dst_layout = #ttg.linear<{register=[[1, 0], [2, 0], [4, 0], [32, 0], [0, 32]], lane=[[0, 1], [0, 2], [0, 4], [0, 8], [8, 0], [16, 0]], warp=[[0, 16]], block=[]}>
module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
  tt.func @concat_ll_2d_2(
      %arg0: tensor<32x32xf32, #src_layout>,
      %arg1: tensor<32x32xf32, #src_layout>,
      %arg2: tensor<32x32xf32, #src_layout>,
      %arg3: tensor<32x32xf32, #src_layout>) {
    // CHECK: llvm.func @concat_ll_2d_2

    // CHECK-COUNT-8: %{{.*}} = llvm.extractvalue %arg0[{{.*}}] : !llvm.struct
    // CHECK-COUNT-8: %{{.*}} = llvm.extractvalue %arg1[{{.*}}] : !llvm.struct
    // CHECK-COUNT-8: %{{.*}} = llvm.extractvalue %arg2[{{.*}}] : !llvm.struct
    // CHECK-COUNT-8: %{{.*}} = llvm.extractvalue %arg3[{{.*}}] : !llvm.struct
    // CHECK-COUNT-32: %{{.*}} = llvm.insertvalue %{{.*}} : !llvm.struct

    %1 = amdgpu.concat %arg0, %arg1, %arg2, %arg3 :
      tensor<32x32xf32, #src_layout>, tensor<32x32xf32, #src_layout>, tensor<32x32xf32, #src_layout>, tensor<32x32xf32, #src_layout> -> tensor<64x64xf32, #dst_layout>
    tt.return
  }
}

// -----

#src_layout = #ttg.linear<{register=[[1]], lane=[[2], [4], [8], [16], [32], [64]], warp=[[128]], block=[]}>
#dst_layout = #ttg.linear<{register=[[1], [256], [512]], lane=[[2], [4], [8], [16], [32], [64]], warp=[[128]], block=[]}>
module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
  tt.func @concat_ll_1d(
      %arg0: tensor<256xf32, #src_layout>,
      %arg1: tensor<256xf32, #src_layout>,
      %arg2: tensor<256xf32, #src_layout>,
      %arg3: tensor<256xf32, #src_layout>) {
    // CHECK: llvm.func @concat_ll_1d

    // CHECK-COUNT-2: %{{.*}} = llvm.extractvalue %arg0[{{.*}}] : !llvm.struct
    // CHECK-COUNT-2: %{{.*}} = llvm.extractvalue %arg1[{{.*}}] : !llvm.struct
    // CHECK-COUNT-2: %{{.*}} = llvm.extractvalue %arg2[{{.*}}] : !llvm.struct
    // CHECK-COUNT-2: %{{.*}} = llvm.extractvalue %arg3[{{.*}}] : !llvm.struct
    // CHECK-COUNT-8: %{{.*}} = llvm.insertvalue %{{.*}} : !llvm.struct

    %1 = amdgpu.concat %arg0, %arg1, %arg2, %arg3 :
      tensor<256xf32, #src_layout>, tensor<256xf32, #src_layout>, tensor<256xf32, #src_layout>, tensor<256xf32, #src_layout> -> tensor<1024xf32, #dst_layout>
    tt.return
  }
}

third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td

Lines changed: 69 additions & 0 deletions
@@ -119,6 +119,75 @@ def ExtractSliceOp : TT_AMDGPU_Op<"extract_slice", [Pure]> {
  let hasVerifier = 1;
}

def ConcatOp : TT_AMDGPU_Op<"concat", [Pure]> {
  let summary = "concat operation";
  let description = [{
    The "concat" operation combines a list of source n-dimensional tensors into a single larger destination tensor.

    All source tensors must have the same shape, element type, and encoding.
    The concatenation dimension is inferred from the source and destination shapes provided by the user.
    For example, two tensors of shape 64x128 can produce a destination shape of 128x128,
    indicating concatenation along dimension 0; or 64x256, indicating concatenation along dimension 1.

    Generally, source tensors passed as op arguments can be arranged into the resulting shape in multiple ways.
    For example, given four tensors of shape 64x64:
      concat s0<64x64>, s1<64x64>, s2<64x64>, s3<64x64> -> <128x128>

    They can be laid out in different configurations within the result tensor:
      1) s0 s1    2) s0 s2
         s2 s3       s1 s3

    From a logical tensor perspective, the source tensors are treated as elements of a tensor of tensors.
    In other words, the 1-D array of input tensors is conceptually reshaped into an n-D grid.
    The semantics of this op assume a row-major order (or its n-D generalization),
    meaning the fastest-varying dimension is filled first, and the slowest-varying dimension is filled last.
    In the example above, this corresponds to layout 1).

    The source and destination tensors must have identical linear layouts at the CTA tile level.
    That is, all base vectors for input dimensions must match, except for the register input dimension.
    The register basis must align on the subset that defines the logical tensor shape of a single CTA tile.

    This ensures that the concatenation is a no-op, meaning no data rearrangement among threads is required
    to assemble the destination tensor with the given shape and layout.
    However, the order of CTA tiles within the layout does not need to match between source and destination layouts.
    It is the responsibility of the op's lowering logic to handle this correctly.

    This op is designed to work on logical tensors directly, avoiding the need for complex layout reinterpretation or reshaping.
    For example, the `tt.join` operation only supports concatenation along the innermost dimension,
    and requires that the resulting innermost dimension provide 2 elements per thread, distributed across registers.
    In contrast, this `concat` op imposes no constraints on the concatenation dimension or the size of dimensions.

    * sources: a list of the input tensors.

    Example 1:

    ```mlir
    #blocked = #ttg.blocked<{sizePerThread = [1, 8],
        threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
    %0 = amdgpu.concat %arg0, %arg1 : tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>
        -> tensor<64x64xf32, #blocked>
    ```

    Example 2:
    ```mlir
    #src_layout = #ttg.linear<{register=[[0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [64, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 4]], warp=[[0, 32], [32, 0]], block=[]}>
    #dst_layout = #ttg.linear<{register=[[0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [0, 128], [64, 0], [128, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 4]], warp=[[0, 32], [32, 0]], block=[]}>
    %0 = amdgpu.concat %arg0, %arg1, %arg2, %arg3 : tensor<128x128xf16, #src_layout>, tensor<128x128xf16, #src_layout>, tensor<128x128xf16, #src_layout>,
        tensor<128x128xf16, #src_layout> -> tensor<256x256xf16, #dst_layout>
    ```
  }];

  let arguments = (ins Variadic<TT_Tensor>:$sources);
  let results = (outs AnyRankedTensor:$result);

  let assemblyFormat = [{
    $sources attr-dict `:` type($sources) `->` type($result)
  }];

  let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// InstructionSchedHint
//===----------------------------------------------------------------------===//
