Commit 1f4f57e

[XPU][TritonIntelGPUToLLVM] Handle special "unbroadcast" case
Unbroadcast is a special case of converting from a slice layout to a non-sliced layout: the required transformation simply "drops" values owned by a given thread, as they are already handled by other threads in the same warp. As an example, going from a sliced layout with a single warp, in which every thread owns all elements, to a blocked layout, in which elements are distributed across threads, is an unbroadcast conversion.

Implement this as a sequence of select operations using the sub-group local ID as a discriminator; this leads to good code generation in IGC. Additionally, modify the analyses so that no SLM is allocated for this kind of conversion and thus no barrier is introduced either.

Signed-off-by: victor-eds <[email protected]>
1 parent e8b34a0
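To make the lowering concrete, below is a minimal C++ sketch of the select chain described above, as executed by one sub-group lane. It is an illustration, not the pass implementation: `laneId` stands in for the result of `get_sub_group_local_id()`, `vals` for the sub-group-size worth of values every thread redundantly owns under the sliced layout, and the 0.0f seed replaces the `llvm.mlir.poison` value the real lowering starts from.

#include <array>
#include <cstddef>

// Each lane scans the redundant values and keeps only the one whose index
// matches its own lane ID; every loop step maps to one llvm.select in the IR.
float unbroadcast(const std::array<float, 16> &vals, int laneId) {
  float acc = 0.0f; // placeholder seed; the emitted IR uses poison here
  for (std::size_t i = 0; i < vals.size(); ++i)
    acc = (laneId == static_cast<int>(i)) ? vals[i] : acc;
  return acc;
}

With 16 redundant copies per thread and 16 lanes, each lane ends up holding exactly one value and no data moves between threads, which is why neither SLM nor a barrier is needed.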

File tree

6 files changed, +337 -0 lines changed

test/Conversion/intel/intel-allocate-shared-memory.mlir

Lines changed: 16 additions & 0 deletions
@@ -63,3 +63,19 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
     tt.return %0 : tensor<128x64xf32, #blocked1>
   }
 }
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [1], order = [0]}>
+
+// Check no scratch memory needed for unbroadcast.
+
+// CHECK-LABEL: module attributes
+// CHECK-SAME: triton_gpu.shared = 0 : i32
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+  tt.func @test_basic(%arg0: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<16xf32, #blocked1> {
+    %0 = triton_gpu.convert_layout %arg0 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16xf32, #blocked1>
+    tt.return %0 : tensor<16xf32, #blocked1>
+  }
+}
Lines changed: 226 additions & 0 deletions
@@ -0,0 +1,226 @@
+// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s
+
+// Basic sub-group unbroadcast.
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [1], order = [0]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: llvm.func spir_kernelcc @test_basic(
+  // CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+  tt.func @test_basic(%arg0: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<16xf32, #blocked1> {
+    // CHECK: %[[VAL_1:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+    // CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_0]][1] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+    // CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_0]][2] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+    // CHECK: %[[VAL_4:.*]] = llvm.extractvalue %[[VAL_0]][3] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+    // CHECK: %[[VAL_5:.*]] = llvm.extractvalue %[[VAL_0]][4] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+    // CHECK: %[[VAL_6:.*]] = llvm.extractvalue %[[VAL_0]][5] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+    // CHECK: %[[VAL_7:.*]] = llvm.extractvalue %[[VAL_0]][6] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+    // CHECK: %[[VAL_8:.*]] = llvm.extractvalue %[[VAL_0]][7] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+    // CHECK: %[[VAL_9:.*]] = llvm.extractvalue %[[VAL_0]][8] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+    // CHECK: %[[VAL_10:.*]] = llvm.extractvalue %[[VAL_0]][9] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+    // CHECK: %[[VAL_11:.*]] = llvm.extractvalue %[[VAL_0]][10] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+    // CHECK: %[[VAL_12:.*]] = llvm.extractvalue %[[VAL_0]][11] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+    // CHECK: %[[VAL_13:.*]] = llvm.extractvalue %[[VAL_0]][12] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+    // CHECK: %[[VAL_14:.*]] = llvm.extractvalue %[[VAL_0]][13] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+    // CHECK: %[[VAL_15:.*]] = llvm.extractvalue %[[VAL_0]][14] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+    // CHECK: %[[VAL_16:.*]] = llvm.extractvalue %[[VAL_0]][15] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+    // CHECK: %[[VAL_17:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32
+    // CHECK: %[[VAL_18:.*]] = llvm.zext %[[VAL_17]] : i32 to i64
+    // CHECK: %[[VAL_19:.*]] = llvm.trunc %[[VAL_18]] : i64 to i32
+    // CHECK: %[[VAL_20:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK: %[[VAL_21:.*]] = llvm.icmp "eq" %[[VAL_19]], %[[VAL_20]] : i32
+    // CHECK: %[[VAL_22:.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK: %[[VAL_23:.*]] = llvm.icmp "eq" %[[VAL_19]], %[[VAL_22]] : i32
+    // CHECK: %[[VAL_24:.*]] = llvm.mlir.constant(2 : i32) : i32
+    // CHECK: %[[VAL_25:.*]] = llvm.icmp "eq" %[[VAL_19]], %[[VAL_24]] : i32
+    // CHECK: %[[VAL_26:.*]] = llvm.mlir.constant(3 : i32) : i32
+    // CHECK: %[[VAL_27:.*]] = llvm.icmp "eq" %[[VAL_19]], %[[VAL_26]] : i32
+    // CHECK: %[[VAL_28:.*]] = llvm.mlir.constant(4 : i32) : i32
+    // CHECK: %[[VAL_29:.*]] = llvm.icmp "eq" %[[VAL_19]], %[[VAL_28]] : i32
+    // CHECK: %[[VAL_30:.*]] = llvm.mlir.constant(5 : i32) : i32
+    // CHECK: %[[VAL_31:.*]] = llvm.icmp "eq" %[[VAL_19]], %[[VAL_30]] : i32
+    // CHECK: %[[VAL_32:.*]] = llvm.mlir.constant(6 : i32) : i32
+    // CHECK: %[[VAL_33:.*]] = llvm.icmp "eq" %[[VAL_19]], %[[VAL_32]] : i32
+    // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(7 : i32) : i32
+    // CHECK: %[[VAL_35:.*]] = llvm.icmp "eq" %[[VAL_19]], %[[VAL_34]] : i32
+    // CHECK: %[[VAL_36:.*]] = llvm.mlir.constant(8 : i32) : i32
+    // CHECK: %[[VAL_37:.*]] = llvm.icmp "eq" %[[VAL_19]], %[[VAL_36]] : i32
+    // CHECK: %[[VAL_38:.*]] = llvm.mlir.constant(9 : i32) : i32
+    // CHECK: %[[VAL_39:.*]] = llvm.icmp "eq" %[[VAL_19]], %[[VAL_38]] : i32
+    // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(10 : i32) : i32
+    // CHECK: %[[VAL_41:.*]] = llvm.icmp "eq" %[[VAL_19]], %[[VAL_40]] : i32
+    // CHECK: %[[VAL_42:.*]] = llvm.mlir.constant(11 : i32) : i32
+    // CHECK: %[[VAL_43:.*]] = llvm.icmp "eq" %[[VAL_19]], %[[VAL_42]] : i32
+    // CHECK: %[[VAL_44:.*]] = llvm.mlir.constant(12 : i32) : i32
+    // CHECK: %[[VAL_45:.*]] = llvm.icmp "eq" %[[VAL_19]], %[[VAL_44]] : i32
+    // CHECK: %[[VAL_46:.*]] = llvm.mlir.constant(13 : i32) : i32
+    // CHECK: %[[VAL_47:.*]] = llvm.icmp "eq" %[[VAL_19]], %[[VAL_46]] : i32
+    // CHECK: %[[VAL_48:.*]] = llvm.mlir.constant(14 : i32) : i32
+    // CHECK: %[[VAL_49:.*]] = llvm.icmp "eq" %[[VAL_19]], %[[VAL_48]] : i32
+    // CHECK: %[[VAL_50:.*]] = llvm.mlir.constant(15 : i32) : i32
+    // CHECK: %[[VAL_51:.*]] = llvm.icmp "eq" %[[VAL_19]], %[[VAL_50]] : i32
+    // CHECK: %[[VAL_52:.*]] = llvm.mlir.poison : f32
+    // CHECK: %[[VAL_53:.*]] = llvm.select %[[VAL_21]], %[[VAL_1]], %[[VAL_52]] : i1, f32
+    // CHECK: %[[VAL_54:.*]] = llvm.select %[[VAL_23]], %[[VAL_2]], %[[VAL_53]] : i1, f32
+    // CHECK: %[[VAL_55:.*]] = llvm.select %[[VAL_25]], %[[VAL_3]], %[[VAL_54]] : i1, f32
+    // CHECK: %[[VAL_56:.*]] = llvm.select %[[VAL_27]], %[[VAL_4]], %[[VAL_55]] : i1, f32
+    // CHECK: %[[VAL_57:.*]] = llvm.select %[[VAL_29]], %[[VAL_5]], %[[VAL_56]] : i1, f32
+    // CHECK: %[[VAL_58:.*]] = llvm.select %[[VAL_31]], %[[VAL_6]], %[[VAL_57]] : i1, f32
+    // CHECK: %[[VAL_59:.*]] = llvm.select %[[VAL_33]], %[[VAL_7]], %[[VAL_58]] : i1, f32
+    // CHECK: %[[VAL_60:.*]] = llvm.select %[[VAL_35]], %[[VAL_8]], %[[VAL_59]] : i1, f32
+    // CHECK: %[[VAL_61:.*]] = llvm.select %[[VAL_37]], %[[VAL_9]], %[[VAL_60]] : i1, f32
+    // CHECK: %[[VAL_62:.*]] = llvm.select %[[VAL_39]], %[[VAL_10]], %[[VAL_61]] : i1, f32
+    // CHECK: %[[VAL_63:.*]] = llvm.select %[[VAL_41]], %[[VAL_11]], %[[VAL_62]] : i1, f32
+    // CHECK: %[[VAL_64:.*]] = llvm.select %[[VAL_43]], %[[VAL_12]], %[[VAL_63]] : i1, f32
+    // CHECK: %[[VAL_65:.*]] = llvm.select %[[VAL_45]], %[[VAL_13]], %[[VAL_64]] : i1, f32
+    // CHECK: %[[VAL_66:.*]] = llvm.select %[[VAL_47]], %[[VAL_14]], %[[VAL_65]] : i1, f32
+    // CHECK: %[[VAL_67:.*]] = llvm.select %[[VAL_49]], %[[VAL_15]], %[[VAL_66]] : i1, f32
+    // CHECK: %[[VAL_68:.*]] = llvm.select %[[VAL_51]], %[[VAL_16]], %[[VAL_67]] : i1, f32
+    // CHECK: %[[VAL_69:.*]] = llvm.mlir.undef : !llvm.struct<(f32)>
+    // CHECK: %[[VAL_70:.*]] = llvm.insertvalue %[[VAL_68]], %[[VAL_69]][0] : !llvm.struct<(f32)>
+    %0 = triton_gpu.convert_layout %arg0 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16xf32, #blocked1>
+    // CHECK: llvm.return %[[VAL_70]] : !llvm.struct<(f32)>
+    tt.return %0 : tensor<16xf32, #blocked1>
+  }
+}
+
+// -----
+
+// Sub-group unbroadcast with two elements per thread.
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [16], warpsPerCTA = [1], order = [0]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: llvm.func spir_kernelcc @test_two_els
+  tt.func @test_two_els(%arg0: tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32xf32, #blocked1> {
+    // CHECK: %[[VAL_33:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32
+    // CHECK: %[[VAL_34:.*]] = llvm.zext %[[VAL_33]] : i32 to i64
+    // CHECK: %[[VAL_35:.*]] = llvm.trunc %[[VAL_34]] : i64 to i32
+    // CHECK-COUNT-16: llvm.icmp "eq" %[[VAL_35]], %{{.*}} : i32
+    // CHECK: llvm.select %{{.*}}, %0, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %2, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %4, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %6, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %8, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %10, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %12, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %14, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %16, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %18, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %20, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %22, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %24, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %26, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %28, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %30, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %1, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %3, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %5, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %7, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %9, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %11, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %13, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %15, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %17, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %19, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %21, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %23, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %25, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %27, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %29, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %31, %{{.*}} : i1, f32
+    // COM: We return a struct with two values: go from 32 elements per thread to just 2.
+    // CHECK: %[[VAL_101:.*]] = llvm.mlir.undef : !llvm.struct<(f32, f32)>
+    // CHECK: %[[VAL_102:.*]] = llvm.insertvalue %{{.*}}, %[[VAL_101]][0] : !llvm.struct<(f32, f32)>
+    // CHECK: %[[VAL_103:.*]] = llvm.insertvalue %{{.*}}, %[[VAL_102]][1] : !llvm.struct<(f32, f32)>
+    %0 = triton_gpu.convert_layout %arg0 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1>
+    // CHECK: llvm.return %[[VAL_103]] : !llvm.struct<(f32, f32)>
+    tt.return %0 : tensor<32xf32, #blocked1>
+  }
+}
+
+// -----
+
+// Sub-group unbroadcast with four elements per thread and 4 warps.
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [64, 1], threadsPerWarp = [1, 16], warpsPerCTA = [4, 1], order = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [16], warpsPerCTA = [4], order = [0]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: llvm.func spir_kernelcc @test_four_els_four_warps(
+  tt.func @test_four_els_four_warps(%arg0: tensor<256xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<256xf32, #blocked1> {
+    // CHECK: %[[VAL_33:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32
+    // CHECK: %[[VAL_34:.*]] = llvm.zext %[[VAL_33]] : i32 to i64
+    // CHECK: %[[VAL_35:.*]] = llvm.trunc %[[VAL_34]] : i64 to i32
+    // CHECK-COUNT-16: llvm.icmp "eq" %[[VAL_35]], %{{.*}} : i32
+    // CHECK-COUNT-64: llvm.select
+    // COM: We return a struct with four values: go from 64 elements per thread to just 4.
+    // CHECK: %[[VAL_165:.*]] = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32)>
+    // CHECK: %[[VAL_166:.*]] = llvm.insertvalue %{{.*}}, %[[VAL_165]][0] : !llvm.struct<(f32, f32, f32, f32)>
+    // CHECK: %[[VAL_167:.*]] = llvm.insertvalue %{{.*}}, %[[VAL_166]][1] : !llvm.struct<(f32, f32, f32, f32)>
+    // CHECK: %[[VAL_168:.*]] = llvm.insertvalue %{{.*}}, %[[VAL_167]][2] : !llvm.struct<(f32, f32, f32, f32)>
+    // CHECK: %[[VAL_169:.*]] = llvm.insertvalue %{{.*}}, %[[VAL_168]][3] : !llvm.struct<(f32, f32, f32, f32)>
+    %0 = triton_gpu.convert_layout %arg0 : tensor<256xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<256xf32, #blocked1>
+    // CHECK: llvm.return %[[VAL_169]] : !llvm.struct<(f32, f32, f32, f32)>
+    tt.return %0 : tensor<256xf32, #blocked1>
+  }
+}
+
+// -----
+
+// Sub-group unbroadcast with two elements per thread, but repeated layout.
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [1], order = [0]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: llvm.func spir_kernelcc @test_two_els_with_repeat(
+  tt.func @test_two_els_with_repeat(%arg0: tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32xf32, #blocked1> {
+    // CHECK: %[[VAL_33:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32
+    // CHECK: %[[VAL_34:.*]] = llvm.zext %[[VAL_33]] : i32 to i64
+    // CHECK: %[[VAL_35:.*]] = llvm.trunc %[[VAL_34]] : i64 to i32
+    // CHECK-COUNT-16: llvm.icmp "eq" %[[VAL_35]], %{{.*}} : i32
+    // COM: Check select order differs from above as we have a single element per thread.
+    // CHECK: llvm.select %{{.*}}, %0, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %1, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %2, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %3, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %4, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %5, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %6, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %7, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %8, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %9, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %10, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %11, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %12, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %13, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %14, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %15, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %1, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %2, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %3, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %4, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %5, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %6, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %7, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %8, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %9, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %10, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %11, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %12, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %13, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %14, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %15, %{{.*}} : i1, f32
+    // CHECK: llvm.select %{{.*}}, %16, %{{.*}} : i1, f32
+    // COM: We return a struct with two values: go from 32 elements per thread to just 2.
+    // CHECK: %[[VAL_165:.*]] = llvm.mlir.undef : !llvm.struct<(f32, f32)>
+    // CHECK: %[[VAL_166:.*]] = llvm.insertvalue %{{.*}}, %[[VAL_165]][0] : !llvm.struct<(f32, f32)>
+    // CHECK: %[[VAL_167:.*]] = llvm.insertvalue %{{.*}}, %[[VAL_166]][1] : !llvm.struct<(f32, f32)>
+    %0 = triton_gpu.convert_layout %arg0 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1>
+    // CHECK: llvm.return %[[VAL_167]] : !llvm.struct<(f32, f32)>
+    tt.return %0 : tensor<32xf32, #blocked1>
+  }
+}

third_party/intel/include/Analysis/Utility.h

Lines changed: 4 additions & 0 deletions
@@ -16,6 +16,10 @@ bool cvtIsSubGroupTranspose(RankedTensorType srcTy, RankedTensorType dstTy);
 /// Return whether `type` is a valid element type for a fast sub-group
 /// transpose.
 bool isValidElementTypeForSubGroupTranspose(Type type);
+/// Return whether the layout conversion from `srcTy` to `dstTy` can be
+/// performed as an "unbroadcast" transformation, i.e., dropping duplicated
+/// values.
+bool cvtIsUnbroadcast(RankedTensorType srcTy, RankedTensorType dstTy);

 } // namespace mlir::triton::gpu::intel
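As a rough model of the property this predicate checks (the actual implementation operates on Triton layout attributes, so everything below is an illustrative stand-in, not the real API): a conversion is unbroadcast-like when every value a thread must own in the destination layout is already among the values it owns in the source layout, so the conversion can only drop duplicates.

#include <algorithm>
#include <vector>

// Hypothetical per-thread ownership sets: global element indices held by one
// thread before (srcOwned) and after (dstOwned) the layout conversion.
bool isUnbroadcastLike(const std::vector<int> &srcOwned,
                       const std::vector<int> &dstOwned) {
  // If no destination element needs data from another thread, the conversion
  // requires neither shared local memory nor a barrier.
  return std::all_of(dstOwned.begin(), dstOwned.end(), [&](int e) {
    return std::find(srcOwned.begin(), srcOwned.end(), e) != srcOwned.end();
  });
}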

third_party/intel/lib/Analysis/Allocation.cpp

Lines changed: 5 additions & 0 deletions
@@ -112,6 +112,11 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
     return ScratchConfig({}, {});
   }

+  if (gpu::intel::cvtIsUnbroadcast(srcTy, dstTy)) {
+    // Conversions identified as "unbroadcast" do not need scratch memory.
+    return ScratchConfig({}, {});
+  }
+
   if (gpu::intel::cvtIsSubGroupTranspose(srcTy, dstTy)) {
     // Conversions that can be implemented as sub-group transposes store the
     // whole tensor in shared memory and read it afterwards.