Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2199,7 +2199,9 @@ def VecShuffleOp : CIR_Op<"vec.shuffle",
`(` $vec1 `,` $vec2 `:` qualified(type($vec1)) `)` $indices `:`
qualified(type($result)) attr-dict
}];

let hasVerifier = 1;
let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
45 changes: 42 additions & 3 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1580,9 +1580,49 @@ OpFoldResult cir::VecExtractOp::fold(FoldAdaptor adaptor) {
}

//===----------------------------------------------------------------------===//
// VecShuffle
// VecShuffleOp
//===----------------------------------------------------------------------===//

OpFoldResult cir::VecShuffleOp::fold(FoldAdaptor adaptor) {
mlir::Attribute vec1 = adaptor.getVec1();
mlir::Attribute vec2 = adaptor.getVec2();

if (!mlir::isa_and_nonnull<cir::ConstVectorAttr>(vec1) ||
!mlir::isa_and_nonnull<cir::ConstVectorAttr>(vec2)) {
return {};
}

auto vec1Attr = mlir::cast<cir::ConstVectorAttr>(vec1);
auto vec2Attr = mlir::cast<cir::ConstVectorAttr>(vec2);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
mlir::Attribute vec1 = adaptor.getVec1();
mlir::Attribute vec2 = adaptor.getVec2();
if (!mlir::isa_and_nonnull<cir::ConstVectorAttr>(vec1) ||
!mlir::isa_and_nonnull<cir::ConstVectorAttr>(vec2)) {
return {};
}
auto vec1Attr = mlir::cast<cir::ConstVectorAttr>(vec1);
auto vec2Attr = mlir::cast<cir::ConstVectorAttr>(vec2);
auto vec1Attr = mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec1());
auto vec2Attr = mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec2()));
if (!vec1Attr || !vec2Attr)
return {};

mlir::Type vec1ElemTy =
mlir::cast<cir::VectorType>(vec1Attr.getType()).getElementType();

mlir::ArrayAttr vec1Elts = vec1Attr.getElts();
mlir::ArrayAttr vec2Elts = vec2Attr.getElts();
mlir::ArrayAttr indicesElts = adaptor.getIndices();

SmallVector<mlir::Attribute, 16> elements;
elements.reserve(indicesElts.size());

uint64_t vec1Size = vec1Elts.size();
for (const auto &idxAttr : indicesElts.getAsRange<cir::IntAttr>()) {
if (idxAttr.getSInt() == -1) {
elements.push_back(cir::UndefAttr::get(vec1ElemTy));
continue;
}

uint64_t idxValue = idxAttr.getUInt();
if (idxValue < vec1Size) {
elements.push_back(vec1Elts[idxValue]);
} else {
elements.push_back(vec2Elts[idxValue - vec1Size]);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

elements.push_back(idxValue < vec1Size ? vec1Elts[idxValue] : vec2Elts[idxValue - vec1Size])

}
}

return cir::ConstVectorAttr::get(
getType(), mlir::ArrayAttr::get(getContext(), elements));
}

LogicalResult cir::VecShuffleOp::verify() {
// The number of elements in the indices array must match the number of
// elements in the result type.
Expand Down Expand Up @@ -1613,7 +1653,6 @@ OpFoldResult cir::VecShuffleDynamicOp::fold(FoldAdaptor adaptor) {
mlir::isa_and_nonnull<cir::ConstVectorAttr>(indices)) {
auto vecAttr = mlir::cast<cir::ConstVectorAttr>(vec);
auto indicesAttr = mlir::cast<cir::ConstVectorAttr>(indices);
auto vecTy = mlir::cast<cir::VectorType>(vecAttr.getType());

mlir::ArrayAttr vecElts = vecAttr.getElts();
mlir::ArrayAttr indicesElts = indicesAttr.getElts();
Expand All @@ -1631,7 +1670,7 @@ OpFoldResult cir::VecShuffleDynamicOp::fold(FoldAdaptor adaptor) {
}

return cir::ConstVectorAttr::get(
vecTy, mlir::ArrayAttr::get(getContext(), elements));
getType(), mlir::ArrayAttr::get(getContext(), elements));
}

return {};
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ void CIRCanonicalizePass::runOnOperation() {
// Many operations are here to perform a manual `fold` in
// applyOpPatternsGreedily.
if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SwitchOp, SelectOp, UnaryOp,
VecExtractOp, VecShuffleDynamicOp, VecTernaryOp>(op))
VecExtractOp, VecShuffleOp, VecShuffleDynamicOp, VecTernaryOp>(op))
ops.push_back(op);
});

Expand Down
59 changes: 59 additions & 0 deletions clang/test/CIR/Transforms/vector-shuffle-fold.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// RUN: cir-opt %s -cir-canonicalize -o - -split-input-file | FileCheck %s

!s32i = !cir.int<s, 32>
!s64i = !cir.int<s, 64>

module {
cir.func @fold_shuffle_vector_op_test() -> !cir.vector<4 x !s32i> {
%vec_1 = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<3> : !s32i, #cir.int<5> : !s32i, #cir.int<7> : !s32i]> : !cir.vector<4 x !s32i>
%vec_2 = cir.const #cir.const_vector<[#cir.int<2> : !s32i, #cir.int<4> : !s32i, #cir.int<6> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i>
%new_vec = cir.vec.shuffle(%vec_1, %vec_2 : !cir.vector<4 x !s32i>) [#cir.int<0> : !s64i, #cir.int<4> : !s64i,
#cir.int<1> : !s64i, #cir.int<5> : !s64i] : !cir.vector<4 x !s32i>
cir.return %new_vec : !cir.vector<4 x !s32i>
}

// CHECK: cir.func @fold_shuffle_vector_op_test() -> !cir.vector<4 x !s32i> {
// CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i,
// CHECK-SAME: #cir.int<4> : !s32i]> : !cir.vector<4 x !s32i>
// CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
}

// -----

!s32i = !cir.int<s, 32>
!s64i = !cir.int<s, 64>

module {
cir.func @fold_shuffle_vector_op_test() -> !cir.vector<6 x !s32i> {
%vec_1 = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<3> : !s32i, #cir.int<5> : !s32i, #cir.int<7> : !s32i]> : !cir.vector<4 x !s32i>
%vec_2 = cir.const #cir.const_vector<[#cir.int<2> : !s32i, #cir.int<4> : !s32i, #cir.int<6> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i>
%new_vec = cir.vec.shuffle(%vec_1, %vec_2 : !cir.vector<4 x !s32i>) [#cir.int<0> : !s64i, #cir.int<4> : !s64i,
#cir.int<1> : !s64i, #cir.int<5> : !s64i, #cir.int<2> : !s64i, #cir.int<6> : !s64i] : !cir.vector<6 x !s32i>
cir.return %new_vec : !cir.vector<6 x !s32i>
}

// CHECK: cir.func @fold_shuffle_vector_op_test() -> !cir.vector<6 x !s32i> {
// CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i,
// CHECK-SAME: #cir.int<4> : !s32i, #cir.int<5> : !s32i, #cir.int<6> : !s32i]> : !cir.vector<6 x !s32i>
// CHECK-NEXT: cir.return %[[RES]] : !cir.vector<6 x !s32i>
}

// -----

!s32i = !cir.int<s, 32>
!s64i = !cir.int<s, 64>

module {
cir.func @fold_shuffle_vector_op_test() -> !cir.vector<4 x !s32i> {
%vec_1 = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<3> : !s32i, #cir.int<5> : !s32i, #cir.int<7> : !s32i]> : !cir.vector<4 x !s32i>
%vec_2 = cir.const #cir.const_vector<[#cir.int<2> : !s32i, #cir.int<4> : !s32i, #cir.int<6> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i>
%new_vec = cir.vec.shuffle(%vec_1, %vec_2 : !cir.vector<4 x !s32i>) [#cir.int<-1> : !s64i, #cir.int<4> : !s64i,
#cir.int<1> : !s64i, #cir.int<5> : !s64i] : !cir.vector<4 x !s32i>
cir.return %new_vec : !cir.vector<4 x !s32i>
}

// CHECK: cir.func @fold_shuffle_vector_op_test() -> !cir.vector<6 x !s32i> {
// CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.undef : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i,
// CHECK-SAME: #cir.int<4> : !s32i, #cir.int<5> : !s32i, #cir.int<6> : !s32i]> : !cir.vector<6 x !s32i>
// CHECK-NEXT: cir.return %[[RES]] : !cir.vector<6 x !s32i>
}