Skip to content

Commit a2a9601

Browse files
authored
[mlir][AMDGPU] Updated PermlaneSwapOp to select correct val (#157586)
* as per the instruction description, updated `PermlaneSwapOp` to select correct val * updated corresponding lit tests Issue it resolves: the block reduction was failing otherwise as we were selecting the `{0}` always. --------- Signed-off-by: xintin <[email protected]>
1 parent 2832717 commit a2a9601

File tree

3 files changed

+91
-31
lines changed

3 files changed

+91
-31
lines changed

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1915,7 +1915,16 @@ struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
19151915
else
19161916
llvm_unreachable("unsupported row length");
19171917

1918-
Value vdstNew = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
1918+
const Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
1919+
const Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});
1920+
1921+
const Value isEqual =
1922+
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, vdst0, v);
1923+
1924+
// Per `permlane(16|32)` semantics: if the first extracted element equals
1925+
// 'v', the result is the second element; otherwise it is the first.
1926+
Value vdstNew =
1927+
rewriter.create<LLVM::SelectOp>(loc, isEqual, vdst1, vdst0);
19191928
permuted.emplace_back(vdstNew);
19201929
}
19211930

mlir/test/Conversion/AMDGPUToROCDL/permlane.mlir

100644100755
Lines changed: 71 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@
44
// CHECK-SAME: (%[[ARG0:.*]]: i32)
55
func.func @test_permlane16_i32(%arg0 : i32) -> i32 {
66
// CHECK: %[[PERM:.*]] = rocdl.permlane16.swap %[[ARG0]], %[[ARG0]], false, false : (i32, i32) -> <(i32, i32)>
7-
// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
8-
// CHECK: return %[[RES]] : i32
7+
// CHECK: %[[E0:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
8+
// CHECK: %[[E1:.*]] = llvm.extractvalue %[[PERM]][1] : !llvm.struct<(i32, i32)>
9+
// CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[E0]], %[[ARG0]] : i32
10+
// CHECK: %[[SEL:.*]] = llvm.select %[[CMP]], %[[E1]], %[[E0]] : i1, i32
11+
// CHECK: return %[[SEL]] : i32
912
%0 = amdgpu.permlane_swap %arg0 16 : i32
1013
return %0 : i32
1114
}
@@ -14,8 +17,11 @@ func.func @test_permlane16_i32(%arg0 : i32) -> i32 {
1417
// CHECK-SAME: (%[[ARG0:.*]]: i32)
1518
func.func @test_permlane16_i32_optional_attr(%arg0 : i32) -> i32 {
1619
// CHECK: %[[PERM:.*]] = rocdl.permlane16.swap %[[ARG0]], %[[ARG0]], true, true : (i32, i32) -> <(i32, i32)>
17-
// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
18-
// CHECK: return %[[RES]] : i32
20+
// CHECK: %[[E0:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
21+
// CHECK: %[[E1:.*]] = llvm.extractvalue %[[PERM]][1] : !llvm.struct<(i32, i32)>
22+
// CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[E0]], %[[ARG0]] : i32
23+
// CHECK: %[[SEL:.*]] = llvm.select %[[CMP]], %[[E1]], %[[E0]] : i1, i32
24+
// CHECK: return %[[SEL]] : i32
1925
%0 = amdgpu.permlane_swap %arg0 16 { fetch_inactive = true, bound_ctrl = true } : i32
2026
return %0 : i32
2127
}
@@ -24,8 +30,11 @@ func.func @test_permlane16_i32_optional_attr(%arg0 : i32) -> i32 {
2430
// CHECK-SAME: (%[[ARG0:.*]]: i32)
2531
func.func @test_permlane32_i32(%arg0 : i32) -> i32 {
2632
// CHECK: %[[PERM:.*]] = rocdl.permlane32.swap %[[ARG0]], %[[ARG0]], false, false : (i32, i32) -> <(i32, i32)>
27-
// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
28-
// CHECK: return %[[RES]] : i32
33+
// CHECK: %[[E0:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
34+
// CHECK: %[[E1:.*]] = llvm.extractvalue %[[PERM]][1] : !llvm.struct<(i32, i32)>
35+
// CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[E0]], %[[ARG0]] : i32
36+
// CHECK: %[[SEL:.*]] = llvm.select %[[CMP]], %[[E1]], %[[E0]] : i1, i32
37+
// CHECK: return %[[SEL]] : i32
2938
%0 = amdgpu.permlane_swap %arg0 32 : i32
3039
return %0 : i32
3140
}
@@ -35,8 +44,11 @@ func.func @test_permlane32_i32(%arg0 : i32) -> i32 {
3544
func.func @test_permlane16_f32(%arg0 : f32) -> f32 {
3645
// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f32 to i32
3746
// CHECK: %[[PERM:.*]] = rocdl.permlane16.swap %[[CAST]], %[[CAST]], false, false : (i32, i32) -> <(i32, i32)>
38-
// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
39-
// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[RES]] : i32 to f32
47+
// CHECK: %[[E0:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
48+
// CHECK: %[[E1:.*]] = llvm.extractvalue %[[PERM]][1] : !llvm.struct<(i32, i32)>
49+
// CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[E0]], %[[CAST]] : i32
50+
// CHECK: %[[SEL:.*]] = llvm.select %[[CMP]], %[[E1]], %[[E0]] : i1, i32
51+
// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[SEL]] : i32 to f32
4052
// CHECK: return %[[RES_CAST]] : f32
4153
%0 = amdgpu.permlane_swap %arg0 16 : f32
4254
return %0 : f32
@@ -47,8 +59,11 @@ func.func @test_permlane16_f32(%arg0 : f32) -> f32 {
4759
func.func @test_permlane32_f32(%arg0 : f32) -> f32 {
4860
// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f32 to i32
4961
// CHECK: %[[PERM:.*]] = rocdl.permlane32.swap %[[CAST]], %[[CAST]], false, false : (i32, i32) -> <(i32, i32)>
50-
// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
51-
// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[RES]] : i32 to f32
62+
// CHECK: %[[E0:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
63+
// CHECK: %[[E1:.*]] = llvm.extractvalue %[[PERM]][1] : !llvm.struct<(i32, i32)>
64+
// CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[E0]], %[[CAST]] : i32
65+
// CHECK: %[[SEL:.*]] = llvm.select %[[CMP]], %[[E1]], %[[E0]] : i1, i32
66+
// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[SEL]] : i32 to f32
5267
// CHECK: return %[[RES_CAST]] : f32
5368
%0 = amdgpu.permlane_swap %arg0 32 : f32
5469
return %0 : f32
@@ -60,8 +75,11 @@ func.func @test_permlane16_f16(%arg0 : f16) -> f16 {
6075
// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f16 to i16
6176
// CHECK: %[[ZEXT:.*]] = llvm.zext %[[CAST]] : i16 to i32
6277
// CHECK: %[[PERM:.*]] = rocdl.permlane16.swap %[[ZEXT]], %[[ZEXT]], false, false : (i32, i32) -> <(i32, i32)>
63-
// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
64-
// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[RES]] : i32 to i16
78+
// CHECK: %[[E0:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
79+
// CHECK: %[[E1:.*]] = llvm.extractvalue %[[PERM]][1] : !llvm.struct<(i32, i32)>
80+
// CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[E0]], %[[ZEXT]] : i32
81+
// CHECK: %[[SEL:.*]] = llvm.select %[[CMP]], %[[E1]], %[[E0]] : i1, i32
82+
// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SEL]] : i32 to i16
6583
// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i16 to f16
6684
// CHECK: return %[[RES_CAST]] : f16
6785
%0 = amdgpu.permlane_swap %arg0 16 : f16
@@ -74,8 +92,11 @@ func.func @test_permlane32_f16(%arg0 : f16) -> f16 {
7492
// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f16 to i16
7593
// CHECK: %[[ZEXT:.*]] = llvm.zext %[[CAST]] : i16 to i32
7694
// CHECK: %[[PERM:.*]] = rocdl.permlane32.swap %[[ZEXT]], %[[ZEXT]], false, false : (i32, i32) -> <(i32, i32)>
77-
// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
78-
// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[RES]] : i32 to i16
95+
// CHECK: %[[E0:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
96+
// CHECK: %[[E1:.*]] = llvm.extractvalue %[[PERM]][1] : !llvm.struct<(i32, i32)>
97+
// CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[E0]], %[[ZEXT]] : i32
98+
// CHECK: %[[SEL:.*]] = llvm.select %[[CMP]], %[[E1]], %[[E0]] : i1, i32
99+
// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SEL]] : i32 to i16
79100
// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i16 to f16
80101
// CHECK: return %[[RES_CAST]] : f16
81102
%0 = amdgpu.permlane_swap %arg0 32 : f16
@@ -90,10 +111,16 @@ func.func @test_permlane16_2xi32(%arg0 : vector<2xi32>) -> vector<2xi32> {
90111
// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
91112
// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[ARG0]][%[[C0]] : i32] : vector<2xi32>
92113
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[ARG0]][%[[C1]] : i32] : vector<2xi32>
93-
// CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
94-
// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
95-
// CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
96-
// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
114+
// CHECK: %[[T0:.*]] = rocdl.permlane16.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
115+
// CHECK: %[[T0_0:.*]] = llvm.extractvalue %[[T0]][0] : !llvm.struct<(i32, i32)>
116+
// CHECK: %[[T0_1:.*]] = llvm.extractvalue %[[T0]][1] : !llvm.struct<(i32, i32)>
117+
// CHECK: %[[CMP0:.*]] = llvm.icmp "eq" %[[T0_0]], %[[ELEM0]] : i32
118+
// CHECK: %[[PERM0:.*]] = llvm.select %[[CMP0]], %[[T0_1]], %[[T0_0]] : i1, i32
119+
// CHECK: %[[T1:.*]] = rocdl.permlane16.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
120+
// CHECK: %[[T1_0:.*]] = llvm.extractvalue %[[T1]][0] : !llvm.struct<(i32, i32)>
121+
// CHECK: %[[T1_1:.*]] = llvm.extractvalue %[[T1]][1] : !llvm.struct<(i32, i32)>
122+
// CHECK: %[[CMP1:.*]] = llvm.icmp "eq" %[[T1_0]], %[[ELEM1]] : i32
123+
// CHECK: %[[PERM1:.*]] = llvm.select %[[CMP1]], %[[T1_1]], %[[T1_0]] : i1, i32
97124
// CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32>
98125
// CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32>
99126
// CHECK: return %[[VEC_INSERT1]] : vector<2xi32>
@@ -109,10 +136,16 @@ func.func @test_permlane32_2xi32(%arg0 : vector<2xi32>) -> vector<2xi32> {
109136
// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
110137
// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[ARG0]][%[[C0]] : i32] : vector<2xi32>
111138
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[ARG0]][%[[C1]] : i32] : vector<2xi32>
112-
// CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
113-
// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
114-
// CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
115-
// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
139+
// CHECK: %[[T0:.*]] = rocdl.permlane32.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
140+
// CHECK: %[[T0_0:.*]] = llvm.extractvalue %[[T0]][0] : !llvm.struct<(i32, i32)>
141+
// CHECK: %[[T0_1:.*]] = llvm.extractvalue %[[T0]][1] : !llvm.struct<(i32, i32)>
142+
// CHECK: %[[CMP0:.*]] = llvm.icmp "eq" %[[T0_0]], %[[ELEM0]] : i32
143+
// CHECK: %[[PERM0:.*]] = llvm.select %[[CMP0]], %[[T0_1]], %[[T0_0]] : i1, i32
144+
// CHECK: %[[T1:.*]] = rocdl.permlane32.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
145+
// CHECK: %[[T1_0:.*]] = llvm.extractvalue %[[T1]][0] : !llvm.struct<(i32, i32)>
146+
// CHECK: %[[T1_1:.*]] = llvm.extractvalue %[[T1]][1] : !llvm.struct<(i32, i32)>
147+
// CHECK: %[[CMP1:.*]] = llvm.icmp "eq" %[[T1_0]], %[[ELEM1]] : i32
148+
// CHECK: %[[PERM1:.*]] = llvm.select %[[CMP1]], %[[T1_1]], %[[T1_0]] : i1, i32
116149
// CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32>
117150
// CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32>
118151
// CHECK: return %[[VEC_INSERT1]] : vector<2xi32>
@@ -130,9 +163,15 @@ func.func @test_permlane16_4xf16(%arg0 : vector<4xf16>) -> vector<4xf16> {
130163
// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[CAST1]][%[[C0]] : i32] : vector<2xi32>
131164
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[CAST1]][%[[C1]] : i32] : vector<2xi32>
132165
// CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
133-
// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
166+
// CHECK: %[[PERM0_E0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
167+
// CHECK: %[[PERM0_E1:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][1] : !llvm.struct<(i32, i32)>
168+
// CHECK: %[[CMP0:.*]] = llvm.icmp "eq" %[[PERM0_E0]], %[[ELEM0]] : i32
169+
// CHECK: %[[PERM0:.*]] = llvm.select %[[CMP0]], %[[PERM0_E1]], %[[PERM0_E0]] : i1, i32
134170
// CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
135-
// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
171+
// CHECK: %[[PERM1_E0:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
172+
// CHECK: %[[PERM1_E1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][1] : !llvm.struct<(i32, i32)>
173+
// CHECK: %[[CMP1:.*]] = llvm.icmp "eq" %[[PERM1_E0]], %[[ELEM1]] : i32
174+
// CHECK: %[[PERM1:.*]] = llvm.select %[[CMP1]], %[[PERM1_E1]], %[[PERM1_E0]] : i1, i32
136175
// CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32>
137176
// CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32>
138177
// CHECK: %[[CAST2:.*]] = llvm.bitcast %[[VEC_INSERT1]] : vector<2xi32> to vector<4xf16>
@@ -151,9 +190,15 @@ func.func @test_permlane32_4xf16(%arg0 : vector<4xf16>) -> vector<4xf16> {
151190
// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[CAST1]][%[[C0]] : i32] : vector<2xi32>
152191
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[CAST1]][%[[C1]] : i32] : vector<2xi32>
153192
// CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
154-
// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
193+
// CHECK: %[[PERM0_E0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
194+
// CHECK: %[[PERM0_E1:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][1] : !llvm.struct<(i32, i32)>
195+
// CHECK: %[[CMP0:.*]] = llvm.icmp "eq" %[[PERM0_E0]], %[[ELEM0]] : i32
196+
// CHECK: %[[PERM0:.*]] = llvm.select %[[CMP0]], %[[PERM0_E1]], %[[PERM0_E0]] : i1, i32
155197
// CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
156-
// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
198+
// CHECK: %[[PERM1_E0:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
199+
// CHECK: %[[PERM1_E1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][1] : !llvm.struct<(i32, i32)>
200+
// CHECK: %[[CMP1:.*]] = llvm.icmp "eq" %[[PERM1_E0]], %[[ELEM1]] : i32
201+
// CHECK: %[[PERM1:.*]] = llvm.select %[[CMP1]], %[[PERM1_E1]], %[[PERM1_E0]] : i1, i32
157202
// CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32>
158203
// CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32>
159204
// CHECK: %[[CAST2:.*]] = llvm.bitcast %[[VEC_INSERT1]] : vector<2xi32> to vector<4xf16>

mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir

100644100755
Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -749,13 +749,19 @@ gpu.module @test_module {
749749
%shfl1, %pred1 = gpu.shuffle xor %arg0, %arg1, %arg4 : f32
750750
// CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
751751
// CHECK: %[[#PERMUTE:]] = rocdl.permlane16.swap %[[#CAST_VALUE]], %[[#CAST_VALUE]], false, false : (i32, i32) -> <(i32, i32)>
752-
// CHECK: %[[#EXTRACT:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)>
753-
// CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#EXTRACT]] : i32 to f32
752+
// CHECK: %[[#EXTRACT0:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)>
753+
// CHECK: %[[#EXTRACT1:]] = llvm.extractvalue %[[#PERMUTE:]][1] : !llvm.struct<(i32, i32)>
754+
// CHECK: %[[#CMP:]] = llvm.icmp "eq" %[[#EXTRACT0]], %[[#CAST_VALUE]] : i32
755+
// CHECK: %[[#SEL:]] = llvm.select %[[#CMP]], %[[#EXTRACT1]], %[[#EXTRACT0]] : i1, i32
756+
// CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#SEL]] : i32 to f32
754757
%shfl2, %pred2 = gpu.shuffle xor %arg0, %arg2, %arg4 : f32
755758
// CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
756759
// CHECK: %[[#PERMUTE:]] = rocdl.permlane32.swap %[[#CAST_VALUE]], %[[#CAST_VALUE]], false, false : (i32, i32) -> <(i32, i32)>
757-
// CHECK: %[[#EXTRACT:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)>
758-
// CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#EXTRACT]] : i32 to f32
760+
// CHECK: %[[#EXTRACT0:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)>
761+
// CHECK: %[[#EXTRACT1:]] = llvm.extractvalue %[[#PERMUTE:]][1] : !llvm.struct<(i32, i32)>
762+
// CHECK: %[[#CMP:]] = llvm.icmp "eq" %[[#EXTRACT0]], %[[#CAST_VALUE]] : i32
763+
// CHECK: %[[#SEL:]] = llvm.select %[[#CMP]], %[[#EXTRACT1]], %[[#EXTRACT0]] : i1, i32
764+
// CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#SEL]] : i32 to f32
759765
%shfl3, %pred3 = gpu.shuffle xor %arg0, %arg3, %arg4 : f32
760766
func.return %shfl1, %shfl2, %shfl3 : f32, f32, f32
761767
}

0 commit comments

Comments
 (0)