@@ -1643,34 +1643,31 @@ OpFoldResult cir::VecTernaryOp::fold(FoldAdaptor adaptor) {
16431643 mlir::Attribute lhs = adaptor.getLhs ();
16441644 mlir::Attribute rhs = adaptor.getRhs ();
16451645
1646- if (mlir::isa_and_nonnull<cir::ConstVectorAttr>(cond) &&
1647- mlir::isa_and_nonnull<cir::ConstVectorAttr>(lhs) &&
1648- mlir::isa_and_nonnull<cir::ConstVectorAttr>(rhs)) {
1649- auto condVec = mlir::cast<cir::ConstVectorAttr>(cond);
1650- auto lhsVec = mlir::cast<cir::ConstVectorAttr>(lhs);
1651- auto rhsVec = mlir::cast<cir::ConstVectorAttr>(rhs);
1652-
1653- mlir::ArrayAttr condElts = condVec.getElts ();
1646+ if (!mlir::isa_and_nonnull<cir::ConstVectorAttr>(cond) ||
1647+ !mlir::isa_and_nonnull<cir::ConstVectorAttr>(lhs) ||
1648+ !mlir::isa_and_nonnull<cir::ConstVectorAttr>(rhs))
1649+ return {};
1650+ auto condVec = mlir::cast<cir::ConstVectorAttr>(cond);
1651+ auto lhsVec = mlir::cast<cir::ConstVectorAttr>(lhs);
1652+ auto rhsVec = mlir::cast<cir::ConstVectorAttr>(rhs);
16541653
1655- SmallVector<mlir::Attribute, 16 > elements;
1656- elements.reserve (condElts.size ());
1654+ mlir::ArrayAttr condElts = condVec.getElts ();
16571655
1658- for (const auto &[idx, condAttr] :
1659- llvm::enumerate (condElts.getAsRange <cir::IntAttr>())) {
1660- if (condAttr.getSInt ()) {
1661- elements.push_back (lhsVec.getElts ()[idx]);
1662- continue ;
1663- }
1656+ SmallVector<mlir::Attribute, 16 > elements;
1657+ elements.reserve (condElts.size ());
16641658
1659+ for (const auto &[idx, condAttr] :
1660+ llvm::enumerate (condElts.getAsRange <cir::IntAttr>())) {
1661+ if (condAttr.getSInt ()) {
1662+ elements.push_back (lhsVec.getElts ()[idx]);
1663+ } else {
16651664 elements.push_back (rhsVec.getElts ()[idx]);
16661665 }
1667-
1668- cir::VectorType vecTy = getLhs ().getType ();
1669- return cir::ConstVectorAttr::get (
1670- vecTy, mlir::ArrayAttr::get (getContext (), elements));
16711666 }
16721667
1673- return {};
1668+ cir::VectorType vecTy = getLhs ().getType ();
1669+ return cir::ConstVectorAttr::get (
1670+ vecTy, mlir::ArrayAttr::get (getContext (), elements));
16741671}
16751672
16761673// ===----------------------------------------------------------------------===//
0 commit comments