Skip to content

Commit ef8322f

Browse files
authored
[MLIR][LLVM] Improve bit- and addrspacecast folders (llvm#87745)
This commit extends the folders of chainable casts (bitcast and addrspacecast) so that a chain of identical casts is folded into a single cast. Additionally, it cleans up the canonicalization test file, which used some outdated constructs.
1 parent 974f1ee commit ef8322f

File tree

3 files changed

+74
-32
lines changed

3 files changed

+74
-32
lines changed

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 18 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -2761,17 +2761,27 @@ LogicalResult SExtOp::verify() { return verifyExtOp<SExtOp>(*this); }
27612761
// Folder and verifier for LLVM::BitcastOp
27622762
//===----------------------------------------------------------------------===//
27632763

2764-
OpFoldResult LLVM::BitcastOp::fold(FoldAdaptor adaptor) {
2765-
// bitcast(x : T0, T0) -> x
2766-
if (getArg().getType() == getType())
2767-
return getArg();
2768-
// bitcast(bitcast(x : T0, T1), T0) -> x
2769-
if (auto prev = getArg().getDefiningOp<BitcastOp>())
2770-
if (prev.getArg().getType() == getType())
2764+
/// Folds a cast op that can be chained.
2765+
template <typename T>
2766+
static Value foldChainableCast(T castOp, typename T::FoldAdaptor adaptor) {
2767+
// cast(x : T0, T0) -> x
2768+
if (castOp.getArg().getType() == castOp.getType())
2769+
return castOp.getArg();
2770+
if (auto prev = castOp.getArg().template getDefiningOp<T>()) {
2771+
// cast(cast(x : T0, T1), T0) -> x
2772+
if (prev.getArg().getType() == castOp.getType())
27712773
return prev.getArg();
2774+
// cast(cast(x : T0, T1), T2) -> cast(x: T0, T2)
2775+
castOp.getArgMutable().set(prev.getArg());
2776+
return Value{castOp};
2777+
}
27722778
return {};
27732779
}
27742780

2781+
OpFoldResult LLVM::BitcastOp::fold(FoldAdaptor adaptor) {
2782+
return foldChainableCast(*this, adaptor);
2783+
}
2784+
27752785
LogicalResult LLVM::BitcastOp::verify() {
27762786
auto resultType = llvm::dyn_cast<LLVMPointerType>(
27772787
extractVectorElementType(getResult().getType()));
@@ -2811,14 +2821,7 @@ LogicalResult LLVM::BitcastOp::verify() {
28112821
//===----------------------------------------------------------------------===//
28122822

28132823
OpFoldResult LLVM::AddrSpaceCastOp::fold(FoldAdaptor adaptor) {
2814-
// addrcast(x : T0, T0) -> x
2815-
if (getArg().getType() == getType())
2816-
return getArg();
2817-
// addrcast(addrcast(x : T0, T1), T0) -> x
2818-
if (auto prev = getArg().getDefiningOp<AddrSpaceCastOp>())
2819-
if (prev.getArg().getType() == getType())
2820-
return prev.getArg();
2821-
return {};
2824+
return foldChainableCast(*this, adaptor);
28222825
}
28232826

28242827
//===----------------------------------------------------------------------===//

mlir/test/Dialect/LLVMIR/canonicalize.mlir

Lines changed: 54 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,8 @@ llvm.func @fold_icmp_eq(%arg0 : i32) -> i1 {
88
llvm.return %0 : i1
99
}
1010

11+
// -----
12+
1113
// CHECK-LABEL: @fold_icmp_ne
1214
llvm.func @fold_icmp_ne(%arg0 : vector<2xi32>) -> vector<2xi1> {
1315
// CHECK: %[[C0:.*]] = llvm.mlir.constant(dense<false> : vector<2xi1>) : vector<2xi1>
@@ -16,6 +18,8 @@ llvm.func @fold_icmp_ne(%arg0 : vector<2xi32>) -> vector<2xi1> {
1618
llvm.return %0 : vector<2xi1>
1719
}
1820

21+
// -----
22+
1923
// CHECK-LABEL: @fold_icmp_alloca
2024
llvm.func @fold_icmp_alloca() -> i1 {
2125
// CHECK: %[[C0:.*]] = llvm.mlir.constant(true) : i1
@@ -83,16 +87,18 @@ llvm.func @fold_unrelated_extractvalue(%arr: !llvm.array<4 x f32>) -> f32 {
8387
// -----
8488

8589
// CHECK-LABEL: fold_bitcast
86-
// CHECK-SAME: %[[a0:arg[0-9]+]]
87-
// CHECK-NEXT: llvm.return %[[a0]]
90+
// CHECK-SAME: %[[ARG:[[:alnum:]]+]]
91+
// CHECK-NEXT: llvm.return %[[ARG]]
8892
llvm.func @fold_bitcast(%x : !llvm.ptr) -> !llvm.ptr {
8993
%c = llvm.bitcast %x : !llvm.ptr to !llvm.ptr
9094
llvm.return %c : !llvm.ptr
9195
}
9296

97+
// -----
98+
9399
// CHECK-LABEL: fold_bitcast2
94-
// CHECK-SAME: %[[a0:arg[0-9]+]]
95-
// CHECK-NEXT: llvm.return %[[a0]]
100+
// CHECK-SAME: %[[ARG:[[:alnum:]]+]]
101+
// CHECK-NEXT: llvm.return %[[ARG]]
96102
llvm.func @fold_bitcast2(%x : i32) -> i32 {
97103
%c = llvm.bitcast %x : i32 to f32
98104
%d = llvm.bitcast %c : f32 to i32
@@ -101,17 +107,31 @@ llvm.func @fold_bitcast2(%x : i32) -> i32 {
101107

102108
// -----
103109

110+
// CHECK-LABEL: fold_bitcast_chain
111+
// CHECK-SAME: %[[ARG:[[:alnum:]]+]]
112+
llvm.func @fold_bitcast_chain(%x : i32) -> vector<2xi16> {
113+
%c = llvm.bitcast %x : i32 to f32
114+
%d = llvm.bitcast %c : f32 to vector<2xi16>
115+
// CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[ARG]] : i32 to vector<2xi16>
116+
// CHECK: llvm.return %[[BITCAST]]
117+
llvm.return %d : vector<2xi16>
118+
}
119+
120+
// -----
121+
104122
// CHECK-LABEL: fold_addrcast
105-
// CHECK-SAME: %[[a0:arg[0-9]+]]
106-
// CHECK-NEXT: llvm.return %[[a0]]
123+
// CHECK-SAME: %[[ARG:[[:alnum:]]+]]
124+
// CHECK-NEXT: llvm.return %[[ARG]]
107125
llvm.func @fold_addrcast(%x : !llvm.ptr) -> !llvm.ptr {
108126
%c = llvm.addrspacecast %x : !llvm.ptr to !llvm.ptr
109127
llvm.return %c : !llvm.ptr
110128
}
111129

130+
// -----
131+
112132
// CHECK-LABEL: fold_addrcast2
113-
// CHECK-SAME: %[[a0:arg[0-9]+]]
114-
// CHECK-NEXT: llvm.return %[[a0]]
133+
// CHECK-SAME: %[[ARG:[[:alnum:]]+]]
134+
// CHECK-NEXT: llvm.return %[[ARG]]
115135
llvm.func @fold_addrcast2(%x : !llvm.ptr) -> !llvm.ptr {
116136
%c = llvm.addrspacecast %x : !llvm.ptr to !llvm.ptr<5>
117137
%d = llvm.addrspacecast %c : !llvm.ptr<5> to !llvm.ptr
@@ -120,28 +140,44 @@ llvm.func @fold_addrcast2(%x : !llvm.ptr) -> !llvm.ptr {
120140

121141
// -----
122142

143+
// CHECK-LABEL: fold_addrcast_chain
144+
// CHECK-SAME: %[[ARG:[[:alnum:]]+]]
145+
llvm.func @fold_addrcast_chain(%x : !llvm.ptr) -> !llvm.ptr<2> {
146+
%c = llvm.addrspacecast %x : !llvm.ptr to !llvm.ptr<1>
147+
%d = llvm.addrspacecast %c : !llvm.ptr<1> to !llvm.ptr<2>
148+
// CHECK: %[[ADDRCAST:.*]] = llvm.addrspacecast %[[ARG]] : !llvm.ptr to !llvm.ptr<2>
149+
// CHECK: llvm.return %[[ADDRCAST]]
150+
llvm.return %d : !llvm.ptr<2>
151+
}
152+
153+
// -----
154+
123155
// CHECK-LABEL: fold_gep
124-
// CHECK-SAME: %[[a0:arg[0-9]+]]
125-
// CHECK-NEXT: llvm.return %[[a0]]
156+
// CHECK-SAME: %[[ARG:[[:alnum:]]+]]
157+
// CHECK-NEXT: llvm.return %[[ARG]]
126158
llvm.func @fold_gep(%x : !llvm.ptr) -> !llvm.ptr {
127159
%c0 = arith.constant 0 : i32
128160
%c = llvm.getelementptr %x[%c0] : (!llvm.ptr, i32) -> !llvm.ptr, i8
129161
llvm.return %c : !llvm.ptr
130162
}
131163

164+
// -----
165+
132166
// CHECK-LABEL: fold_gep_neg
133-
// CHECK-SAME: %[[a0:arg[0-9]+]]
134-
// CHECK-NEXT: %[[RES:.*]] = llvm.getelementptr inbounds %[[a0]][0, 1]
167+
// CHECK-SAME: %[[ARG:[[:alnum:]]+]]
168+
// CHECK-NEXT: %[[RES:.*]] = llvm.getelementptr inbounds %[[ARG]][0, 1]
135169
// CHECK-NEXT: llvm.return %[[RES]]
136170
llvm.func @fold_gep_neg(%x : !llvm.ptr) -> !llvm.ptr {
137171
%c0 = arith.constant 0 : i32
138172
%0 = llvm.getelementptr inbounds %x[%c0, 1] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<(i32, i32)>
139173
llvm.return %0 : !llvm.ptr
140174
}
141175

176+
// -----
177+
142178
// CHECK-LABEL: fold_gep_canon
143-
// CHECK-SAME: %[[a0:arg[0-9]+]]
144-
// CHECK-NEXT: %[[RES:.*]] = llvm.getelementptr %[[a0]][2]
179+
// CHECK-SAME: %[[ARG:[[:alnum:]]+]]
180+
// CHECK-NEXT: %[[RES:.*]] = llvm.getelementptr %[[ARG]][2]
145181
// CHECK-NEXT: llvm.return %[[RES]]
146182
llvm.func @fold_gep_canon(%x : !llvm.ptr) -> !llvm.ptr {
147183
%c2 = arith.constant 2 : i32
@@ -175,6 +211,8 @@ llvm.func @load_dce(%x : !llvm.ptr) {
175211
llvm.return
176212
}
177213

214+
// -----
215+
178216
llvm.mlir.global external @fp() : !llvm.ptr
179217

180218
// CHECK-LABEL: addr_dce
@@ -184,6 +222,8 @@ llvm.func @addr_dce(%x : !llvm.ptr) {
184222
llvm.return
185223
}
186224

225+
// -----
226+
187227
// CHECK-LABEL: alloca_dce
188228
// CHECK-NEXT: llvm.return
189229
llvm.func @alloca_dce() {

mlir/test/Dialect/LLVMIR/mem2reg.mlir

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -793,9 +793,8 @@ llvm.func @store_int_to_vector(%arg: i32) -> vector<4xi8> {
793793
%1 = llvm.alloca %0 x vector<2xi16> {alignment = 4 : i64} : (i32) -> !llvm.ptr
794794
llvm.store %arg, %1 {alignment = 4 : i64} : i32, !llvm.ptr
795795
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> vector<4xi8>
796-
// CHECK: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG]] : i32 to vector<2xi16>
797-
// CHECK: %[[BITCAST1:.*]] = llvm.bitcast %[[BITCAST0]] : vector<2xi16> to vector<4xi8>
798-
// CHECK: llvm.return %[[BITCAST1]]
796+
// CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[ARG]] : i32 to vector<4xi8>
797+
// CHECK: llvm.return %[[BITCAST]]
799798
llvm.return %2 : vector<4xi8>
800799
}
801800

0 commit comments

Comments
 (0)