1919#include " mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
2020#include " mlir/Dialect/MemRef/IR/MemRef.h"
2121#include " mlir/Dialect/Tensor/IR/Tensor.h"
22+ #include " mlir/Dialect/UB/IR/UBOps.h"
2223#include " mlir/Dialect/Utils/IndexingUtils.h"
2324#include " mlir/Dialect/Utils/StructuredOpsUtils.h"
2425#include " mlir/IR/AffineExpr.h"
@@ -1274,6 +1275,13 @@ OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
12741275 return srcElements[posIdx];
12751276}
12761277
1278+ // Returns `true` if `index` is either within [0, maxIndex) or equal to
1279+ // `poisonValue`.
1280+ static bool isValidPositiveIndexOrPoison (int64_t index, int64_t poisonValue,
1281+ int64_t maxIndex) {
1282+ return index == poisonValue || (index >= 0 && index < maxIndex);
1283+ }
1284+
12771285// ===----------------------------------------------------------------------===//
12781286// ExtractOp
12791287// ===----------------------------------------------------------------------===//
@@ -1355,7 +1363,8 @@ LogicalResult vector::ExtractOp::verify() {
13551363 for (auto [idx, pos] : llvm::enumerate (position)) {
13561364 if (auto attr = dyn_cast<Attribute>(pos)) {
13571365 int64_t constIdx = cast<IntegerAttr>(attr).getInt ();
1358- if (constIdx < 0 || constIdx >= getSourceVectorType ().getDimSize (idx)) {
1366+ if (!isValidPositiveIndexOrPoison (
1367+ constIdx, kPoisonIndex , getSourceVectorType ().getDimSize (idx))) {
13591368 return emitOpError (" expected position attribute #" )
13601369 << (idx + 1 )
13611370 << " to be a non-negative integer smaller than the "
@@ -2249,6 +2258,23 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
22492258 resultType.getNumElements ()));
22502259 return success ();
22512260}
2261+
2262+ // / Fold an insert or extract operation into an poison value when a poison index
2263+ // / is found at any dimension of the static position.
2264+ template <typename OpTy>
2265+ LogicalResult foldPoisonIndexInsertExtractOp (OpTy op,
2266+ PatternRewriter &rewriter) {
2267+ auto hasPoisonIndex = [](int64_t index) {
2268+ return index == OpTy::kPoisonIndex ;
2269+ };
2270+
2271+ if (llvm::none_of (op.getStaticPosition (), hasPoisonIndex))
2272+ return failure ();
2273+
2274+ rewriter.replaceOpWithNewOp <ub::PoisonOp>(op, op.getResult ().getType ());
2275+ return success ();
2276+ }
2277+
22522278} // namespace
22532279
22542280void ExtractOp::getCanonicalizationPatterns (RewritePatternSet &results,
@@ -2257,6 +2283,7 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
22572283 ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
22582284 results.add (foldExtractFromShapeCastToShapeCast);
22592285 results.add (foldExtractFromFromElements);
2286+ results.add (foldPoisonIndexInsertExtractOp<ExtractOp>);
22602287}
22612288
22622289static void populateFromInt64AttrArray (ArrayAttr arrayAttr,
@@ -2600,7 +2627,7 @@ LogicalResult ShuffleOp::verify() {
26002627 int64_t indexSize = (v1Type.getRank () == 0 ? 1 : v1Type.getDimSize (0 )) +
26012628 (v2Type.getRank () == 0 ? 1 : v2Type.getDimSize (0 ));
26022629 for (auto [idx, maskPos] : llvm::enumerate (mask)) {
2603- if (maskPos != kMaskPoisonValue && (maskPos < 0 || maskPos >= indexSize))
2630+ if (! isValidPositiveIndexOrPoison (maskPos, kPoisonIndex , indexSize))
26042631 return emitOpError (" mask index #" ) << (idx + 1 ) << " out of range" ;
26052632 }
26062633 return success ();
@@ -2882,7 +2909,8 @@ LogicalResult InsertOp::verify() {
28822909 for (auto [idx, pos] : llvm::enumerate (position)) {
28832910 if (auto attr = pos.dyn_cast <Attribute>()) {
28842911 int64_t constIdx = cast<IntegerAttr>(attr).getInt ();
2885- if (constIdx < 0 || constIdx >= destVectorType.getDimSize (idx)) {
2912+ if (!isValidPositiveIndexOrPoison (constIdx, kPoisonIndex ,
2913+ destVectorType.getDimSize (idx))) {
28862914 return emitOpError (" expected position attribute #" )
28872915 << (idx + 1 )
28882916 << " to be a non-negative integer smaller than the "
@@ -3020,6 +3048,7 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
30203048 MLIRContext *context) {
30213049 results.add <InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
30223050 InsertOpConstantFolder>(context);
3051+ results.add (foldPoisonIndexInsertExtractOp<InsertOp>);
30233052}
30243053
30253054OpFoldResult vector::InsertOp::fold (FoldAdaptor adaptor) {
0 commit comments