Skip to content

Commit 14ffd45

Browse files
committed
[CIR] Implement folder for VecShuffleOp
1 parent 995d74f commit 14ffd45

File tree

4 files changed

+77
-4
lines changed

4 files changed

+77
-4
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2199,7 +2199,9 @@ def VecShuffleOp : CIR_Op<"vec.shuffle",
21992199
`(` $vec1 `,` $vec2 `:` qualified(type($vec1)) `)` $indices `:`
22002200
qualified(type($result)) attr-dict
22012201
}];
2202+
22022203
let hasVerifier = 1;
2204+
let hasFolder = 1;
22032205
}
22042206

22052207
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1580,9 +1580,42 @@ OpFoldResult cir::VecExtractOp::fold(FoldAdaptor adaptor) {
15801580
}
15811581

15821582
//===----------------------------------------------------------------------===//
1583-
// VecShuffle
1583+
// VecShuffleOp
15841584
//===----------------------------------------------------------------------===//
15851585

1586+
OpFoldResult cir::VecShuffleOp::fold(FoldAdaptor adaptor) {
1587+
mlir::Attribute vec1 = adaptor.getVec1();
1588+
mlir::Attribute vec2 = adaptor.getVec2();
1589+
1590+
if (!mlir::isa_and_nonnull<cir::ConstVectorAttr>(vec1) ||
1591+
!mlir::isa_and_nonnull<cir::ConstVectorAttr>(vec2)) {
1592+
return {};
1593+
}
1594+
1595+
auto vec1Attr = mlir::cast<cir::ConstVectorAttr>(vec1);
1596+
auto vec2Attr = mlir::cast<cir::ConstVectorAttr>(vec2);
1597+
1598+
mlir::ArrayAttr vec1Elts = vec1Attr.getElts();
1599+
mlir::ArrayAttr vec2Elts = vec2Attr.getElts();
1600+
mlir::ArrayAttr indicesElts = adaptor.getIndices();
1601+
1602+
SmallVector<mlir::Attribute, 16> elements;
1603+
elements.reserve(indicesElts.size());
1604+
1605+
uint64_t vec1Size = vec1Elts.size();
1606+
for (const auto &idxAttr : indicesElts.getAsRange<cir::IntAttr>()) {
1607+
uint64_t idxValue = idxAttr.getUInt();
1608+
if (idxValue < vec1Size) {
1609+
elements.push_back(vec1Elts[idxValue]);
1610+
} else {
1611+
elements.push_back(vec2Elts[idxValue - vec1Size]);
1612+
}
1613+
}
1614+
1615+
return cir::ConstVectorAttr::get(
1616+
getType(), mlir::ArrayAttr::get(getContext(), elements));
1617+
}
1618+
15861619
LogicalResult cir::VecShuffleOp::verify() {
15871620
// The number of elements in the indices array must match the number of
15881621
// elements in the result type.
@@ -1613,7 +1646,6 @@ OpFoldResult cir::VecShuffleDynamicOp::fold(FoldAdaptor adaptor) {
16131646
mlir::isa_and_nonnull<cir::ConstVectorAttr>(indices)) {
16141647
auto vecAttr = mlir::cast<cir::ConstVectorAttr>(vec);
16151648
auto indicesAttr = mlir::cast<cir::ConstVectorAttr>(indices);
1616-
auto vecTy = mlir::cast<cir::VectorType>(vecAttr.getType());
16171649

16181650
mlir::ArrayAttr vecElts = vecAttr.getElts();
16191651
mlir::ArrayAttr indicesElts = indicesAttr.getElts();
@@ -1631,7 +1663,7 @@ OpFoldResult cir::VecShuffleDynamicOp::fold(FoldAdaptor adaptor) {
16311663
}
16321664

16331665
return cir::ConstVectorAttr::get(
1634-
vecTy, mlir::ArrayAttr::get(getContext(), elements));
1666+
getType(), mlir::ArrayAttr::get(getContext(), elements));
16351667
}
16361668

16371669
return {};

clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ void CIRCanonicalizePass::runOnOperation() {
142142
// Many operations are here to perform a manual `fold` in
143143
// applyOpPatternsGreedily.
144144
if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SwitchOp, SelectOp, UnaryOp,
145-
VecExtractOp, VecShuffleDynamicOp, VecTernaryOp>(op))
145+
VecExtractOp, VecShuffleOp, VecShuffleDynamicOp, VecTernaryOp>(op))
146146
ops.push_back(op);
147147
});
148148

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// RUN: cir-opt %s -cir-canonicalize -o - -split-input-file | FileCheck %s
2+
3+
!s32i = !cir.int<s, 32>
4+
!s64i = !cir.int<s, 64>
5+
6+
module {
7+
cir.func @fold_shuffle_vector_op_test() -> !cir.vector<4 x !s32i> {
8+
%vec_1 = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<3> : !s32i, #cir.int<5> : !s32i, #cir.int<7> : !s32i]> : !cir.vector<4 x !s32i>
9+
%vec_2 = cir.const #cir.const_vector<[#cir.int<2> : !s32i, #cir.int<4> : !s32i, #cir.int<6> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i>
10+
%new_vec = cir.vec.shuffle(%vec_1, %vec_2 : !cir.vector<4 x !s32i>) [#cir.int<0> : !s64i, #cir.int<4> : !s64i,
11+
#cir.int<1> : !s64i, #cir.int<5> : !s64i] : !cir.vector<4 x !s32i>
12+
cir.return %new_vec : !cir.vector<4 x !s32i>
13+
}
14+
15+
// CHECK: cir.func @fold_shuffle_vector_op_test() -> !cir.vector<4 x !s32i> {
16+
// CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i,
17+
// CHECK-SAME: #cir.int<4> : !s32i]> : !cir.vector<4 x !s32i>
18+
// CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
19+
}
20+
21+
// -----
22+
23+
!s32i = !cir.int<s, 32>
24+
!s64i = !cir.int<s, 64>
25+
26+
module {
27+
cir.func @fold_shuffle_vector_op_test() -> !cir.vector<6 x !s32i> {
28+
%vec_1 = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<3> : !s32i, #cir.int<5> : !s32i, #cir.int<7> : !s32i]> : !cir.vector<4 x !s32i>
29+
%vec_2 = cir.const #cir.const_vector<[#cir.int<2> : !s32i, #cir.int<4> : !s32i, #cir.int<6> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i>
30+
%new_vec = cir.vec.shuffle(%vec_1, %vec_2 : !cir.vector<4 x !s32i>) [#cir.int<0> : !s64i, #cir.int<4> : !s64i,
31+
#cir.int<1> : !s64i, #cir.int<5> : !s64i, #cir.int<2> : !s64i, #cir.int<6> : !s64i] : !cir.vector<6 x !s32i>
32+
cir.return %new_vec : !cir.vector<6 x !s32i>
33+
}
34+
35+
// CHECK: cir.func @fold_shuffle_vector_op_test() -> !cir.vector<6 x !s32i> {
36+
// CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i,
37+
// CHECK-SAME: #cir.int<4> : !s32i, #cir.int<5> : !s32i, #cir.int<6> : !s32i]> : !cir.vector<6 x !s32i>
38+
// CHECK-NEXT: cir.return %[[RES]] : !cir.vector<6 x !s32i>
39+
}

0 commit comments

Comments
 (0)