1919#include " mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
2020#include " mlir/Dialect/MemRef/IR/MemRef.h"
2121#include " mlir/Dialect/Tensor/IR/Tensor.h"
22+ #include " mlir/Dialect/UB/IR/UBOps.h"
2223#include " mlir/Dialect/Utils/IndexingUtils.h"
2324#include " mlir/Dialect/Utils/StructuredOpsUtils.h"
2425#include " mlir/IR/AffineExpr.h"
@@ -1274,6 +1275,13 @@ OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
12741275 return srcElements[posIdx];
12751276}
12761277
1278+ // Returns `true` if `index` is either within [0, maxIndex) or equal to
1279+ // `poisonValue`.
1280+ static bool isValidPositiveIndexOrPoison (int64_t index, int64_t poisonValue,
1281+ int64_t maxIndex) {
1282+ return index == poisonValue || (index >= 0 && index < maxIndex);
1283+ }
1284+
12771285// ===----------------------------------------------------------------------===//
12781286// ExtractOp
12791287// ===----------------------------------------------------------------------===//
@@ -1355,11 +1363,12 @@ LogicalResult vector::ExtractOp::verify() {
13551363 for (auto [idx, pos] : llvm::enumerate (position)) {
13561364 if (auto attr = dyn_cast<Attribute>(pos)) {
13571365 int64_t constIdx = cast<IntegerAttr>(attr).getInt ();
1358- if (constIdx < 0 || constIdx >= getSourceVectorType ().getDimSize (idx)) {
1366+ if (!isValidPositiveIndexOrPoison (
1367+ constIdx, kPoisonIndex , getSourceVectorType ().getDimSize (idx))) {
13591368 return emitOpError (" expected position attribute #" )
13601369 << (idx + 1 )
13611370 << " to be a non-negative integer smaller than the "
1362- " corresponding vector dimension" ;
1371+ " corresponding vector dimension or poison (-1) " ;
13631372 }
13641373 }
13651374 }
@@ -1977,12 +1986,26 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
19771986 return fromElementsOp.getElements ()[flatIndex];
19781987}
19791988
1980- OpFoldResult ExtractOp::fold (FoldAdaptor) {
1989+ // / Fold an insert or extract operation into an poison value when a poison index
1990+ // / is found at any dimension of the static position.
1991+ static ub::PoisonAttr
1992+ foldPoisonIndexInsertExtractOp (MLIRContext *context,
1993+ ArrayRef<int64_t > staticPos, int64_t poisonVal) {
1994+ if (!llvm::is_contained (staticPos, poisonVal))
1995+ return ub::PoisonAttr ();
1996+
1997+ return ub::PoisonAttr::get (context);
1998+ }
1999+
2000+ OpFoldResult ExtractOp::fold (FoldAdaptor adaptor) {
19812001 // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
19822002 // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
19832003 // mismatch).
19842004 if (getNumIndices () == 0 && getVector ().getType () == getResult ().getType ())
19852005 return getVector ();
2006+ if (auto res = foldPoisonIndexInsertExtractOp (
2007+ getContext (), adaptor.getStaticPosition (), kPoisonIndex ))
2008+ return res;
19862009 if (succeeded (foldExtractOpFromExtractChain (*this )))
19872010 return getResult ();
19882011 if (auto res = ExtractFromInsertTransposeChainState (*this ).fold ())
@@ -2249,6 +2272,21 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
22492272 resultType.getNumElements ()));
22502273 return success ();
22512274}
2275+
2276+ // / Fold an insert or extract operation into an poison value when a poison index
2277+ // / is found at any dimension of the static position.
2278+ template <typename OpTy>
2279+ LogicalResult
2280+ canonicalizePoisonIndexInsertExtractOp (OpTy op, PatternRewriter &rewriter) {
2281+ if (auto poisonAttr = foldPoisonIndexInsertExtractOp (
2282+ op.getContext (), op.getStaticPosition (), OpTy::kPoisonIndex )) {
2283+ rewriter.replaceOpWithNewOp <ub::PoisonOp>(op, op.getType (), poisonAttr);
2284+ return success ();
2285+ }
2286+
2287+ return failure ();
2288+ }
2289+
22522290} // namespace
22532291
22542292void ExtractOp::getCanonicalizationPatterns (RewritePatternSet &results,
@@ -2257,6 +2295,7 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
22572295 ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
22582296 results.add (foldExtractFromShapeCastToShapeCast);
22592297 results.add (foldExtractFromFromElements);
2298+ results.add (canonicalizePoisonIndexInsertExtractOp<ExtractOp>);
22602299}
22612300
22622301static void populateFromInt64AttrArray (ArrayAttr arrayAttr,
@@ -2600,7 +2639,7 @@ LogicalResult ShuffleOp::verify() {
26002639 int64_t indexSize = (v1Type.getRank () == 0 ? 1 : v1Type.getDimSize (0 )) +
26012640 (v2Type.getRank () == 0 ? 1 : v2Type.getDimSize (0 ));
26022641 for (auto [idx, maskPos] : llvm::enumerate (mask)) {
2603- if (maskPos != kMaskPoisonValue && (maskPos < 0 || maskPos >= indexSize))
2642+ if (! isValidPositiveIndexOrPoison (maskPos, kPoisonIndex , indexSize))
26042643 return emitOpError (" mask index #" ) << (idx + 1 ) << " out of range" ;
26052644 }
26062645 return success ();
@@ -2882,7 +2921,8 @@ LogicalResult InsertOp::verify() {
28822921 for (auto [idx, pos] : llvm::enumerate (position)) {
28832922 if (auto attr = pos.dyn_cast <Attribute>()) {
28842923 int64_t constIdx = cast<IntegerAttr>(attr).getInt ();
2885- if (constIdx < 0 || constIdx >= destVectorType.getDimSize (idx)) {
2924+ if (!isValidPositiveIndexOrPoison (constIdx, kPoisonIndex ,
2925+ destVectorType.getDimSize (idx))) {
28862926 return emitOpError (" expected position attribute #" )
28872927 << (idx + 1 )
28882928 << " to be a non-negative integer smaller than the "
@@ -3020,6 +3060,7 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
30203060 MLIRContext *context) {
30213061 results.add <InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
30223062 InsertOpConstantFolder>(context);
3063+ results.add (canonicalizePoisonIndexInsertExtractOp<InsertOp>);
30233064}
30243065
30253066OpFoldResult vector::InsertOp::fold (FoldAdaptor adaptor) {
@@ -3028,6 +3069,10 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
30283069 // (type mismatch).
30293070 if (getNumIndices () == 0 && getSourceType () == getType ())
30303071 return getSource ();
3072+ if (auto res = foldPoisonIndexInsertExtractOp (
3073+ getContext (), adaptor.getStaticPosition (), kPoisonIndex ))
3074+ return res;
3075+
30313076 return {};
30323077}
30333078
0 commit comments