Skip to content

Commit a2a31c1

Browse files
AmrDeveloperlanza
authored andcommitted
[CIR] Backport folder implementation for VecExtractOp (llvm#1613)
Backport llvm#139304
1 parent ed25119 commit a2a31c1

File tree

4 files changed

+67
-7
lines changed

4 files changed

+67
-7
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3165,6 +3165,7 @@ def VecExtractOp : CIR_Op<"vec.extract", [Pure,
31653165
}];
31663166

31673167
let llvmOp = "ExtractElementOp";
3168+
let hasFolder = 1;
31683169
}
31693170

31703171
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,6 +1098,29 @@ LogicalResult cir::VecShuffleDynamicOp::verify() {
10981098
return success();
10991099
}
11001100

1101+
//===----------------------------------------------------------------------===//
1102+
// VecExtractOp
1103+
//===----------------------------------------------------------------------===//
1104+
1105+
OpFoldResult cir::VecExtractOp::fold(FoldAdaptor adaptor) {
1106+
const auto vectorAttr =
1107+
llvm::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec());
1108+
if (!vectorAttr)
1109+
return {};
1110+
1111+
const auto indexAttr =
1112+
llvm::dyn_cast_if_present<cir::IntAttr>(adaptor.getIndex());
1113+
if (!indexAttr)
1114+
return {};
1115+
1116+
const mlir::ArrayAttr elements = vectorAttr.getElts();
1117+
const uint64_t index = indexAttr.getUInt();
1118+
if (index >= elements.size())
1119+
return {};
1120+
1121+
return elements[index];
1122+
}
1123+
11011124
//===----------------------------------------------------------------------===//
11021125
// ReturnOp
11031126
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,9 @@ struct RemoveEmptyScope : public OpRewritePattern<ScopeOp> {
7979
struct RemoveEmptySwitch : public OpRewritePattern<SwitchOp> {
8080
using OpRewritePattern<SwitchOp>::OpRewritePattern;
8181

82-
LogicalResult matchAndRewrite(SwitchOp op, PatternRewriter &rewriter) const final {
83-
if (!(op.getBody().empty() ||
84-
isa<YieldOp>(op.getBody().front().front())))
82+
LogicalResult matchAndRewrite(SwitchOp op,
83+
PatternRewriter &rewriter) const final {
84+
if (!(op.getBody().empty() || isa<YieldOp>(op.getBody().front().front())))
8585
return failure();
8686

8787
rewriter.eraseOp(op);
@@ -92,7 +92,8 @@ struct RemoveEmptySwitch : public OpRewritePattern<SwitchOp> {
9292
struct RemoveTrivialTry : public OpRewritePattern<TryOp> {
9393
using OpRewritePattern<TryOp>::OpRewritePattern;
9494

95-
LogicalResult matchAndRewrite(TryOp op, PatternRewriter &rewriter) const final {
95+
LogicalResult matchAndRewrite(TryOp op,
96+
PatternRewriter &rewriter) const final {
9697
// FIXME: also check all catch regions are empty
9798
// return success(op.getTryRegion().hasOneBlock());
9899
return mlir::failure();
@@ -116,7 +117,8 @@ struct RemoveTrivialTry : public OpRewritePattern<TryOp> {
116117
struct SimplifyCallOp : public OpRewritePattern<CallOp> {
117118
using OpRewritePattern<CallOp>::OpRewritePattern;
118119

119-
LogicalResult matchAndRewrite(CallOp op, PatternRewriter &rewriter) const final {
120+
LogicalResult matchAndRewrite(CallOp op,
121+
PatternRewriter &rewriter) const final {
120122
// Applicable to cir.call exception ... clean { cir.yield }
121123
mlir::Region *r = &op.getCleanup();
122124
if (r->empty() || !r->hasOneBlock())
@@ -174,10 +176,11 @@ void CIRCanonicalizePass::runOnOperation() {
174176
// Collect operations to apply patterns.
175177
llvm::SmallVector<Operation *, 16> ops;
176178
getOperation()->walk([&](Operation *op) {
177-
// CastOp and UnaryOp are here to perform a manual `fold` in
179+
// CastOp, UnaryOp and VecExtractOp are here to perform a manual `fold` in
178180
// applyOpPatternsGreedily.
179181
if (isa<BrOp, BrCondOp, ScopeOp, SwitchOp, CastOp, TryOp, UnaryOp, SelectOp,
180-
ComplexCreateOp, ComplexRealOp, ComplexImagOp, CallOp>(op))
182+
ComplexCreateOp, ComplexRealOp, ComplexImagOp, CallOp,
183+
VecExtractOp>(op))
181184
ops.push_back(op);
182185
});
183186

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// RUN: cir-opt %s -cir-canonicalize -o - | FileCheck %s
2+
3+
!s32i = !cir.int<s, 32>
4+
5+
module {
6+
cir.func @fold_extract_vector_op_test() {
7+
%init = cir.alloca !s32i, !cir.ptr<!s32i>, ["e", init]
8+
%const_vec = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i]> : !cir.vector<!s32i x 4>
9+
%index = cir.const #cir.int<1> : !s32i
10+
%ele = cir.vec.extract %const_vec[%index : !s32i] : !cir.vector<!s32i x 4>
11+
cir.store %ele, %init : !s32i, !cir.ptr<!s32i>
12+
cir.return
13+
}
14+
15+
// CHECK: %[[INIT:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["e", init]
16+
// CHECK: %[[VALUE:.*]] = cir.const #cir.int<2> : !s32i
17+
// CHECK: cir.store %[[VALUE]], %[[INIT]] : !s32i, !cir.ptr<!s32i>
18+
19+
cir.func @fold_extract_vector_op_index_out_of_bounds_test() {
20+
%init = cir.alloca !s32i, !cir.ptr<!s32i>, ["e", init]
21+
%const_vec = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i]> : !cir.vector<!s32i x 4>
22+
%index = cir.const #cir.int<9> : !s32i
23+
%ele = cir.vec.extract %const_vec[%index : !s32i] : !cir.vector<!s32i x 4>
24+
cir.store %ele, %init : !s32i, !cir.ptr<!s32i>
25+
cir.return
26+
}
27+
28+
// CHECK: %[[INIT:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["e", init]
29+
// CHECK: %[[CONST_VEC:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i]> : !cir.vector<!s32i x 4>
30+
// CHECK: %[[INDEX:.*]] = cir.const #cir.int<9> : !s32i
31+
// CHECK: %[[ELE:.*]] = cir.vec.extract %[[CONST_VEC]][%[[INDEX]] : !s32i] : !cir.vector<!s32i x 4>
32+
// CHECK: cir.store %[[ELE]], %[[INIT]] : !s32i, !cir.ptr<!s32i>
33+
}

0 commit comments

Comments
 (0)