@@ -1986,12 +1986,26 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
19861986 return fromElementsOp.getElements ()[flatIndex];
19871987}
19881988
1989- OpFoldResult ExtractOp::fold (FoldAdaptor) {
1989+ // / Fold an insert or extract operation into an poison value when a poison index
1990+ // / is found at any dimension of the static position.
1991+ static ub::PoisonAttr
1992+ foldPoisonIndexInsertExtractOp (MLIRContext *context,
1993+ ArrayRef<int64_t > staticPos, int64_t poisonVal) {
1994+ if (!llvm::is_contained (staticPos, poisonVal))
1995+ return ub::PoisonAttr ();
1996+
1997+ return ub::PoisonAttr::get (context);
1998+ }
1999+
2000+ OpFoldResult ExtractOp::fold (FoldAdaptor adaptor) {
19902001 // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
19912002 // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
19922003 // mismatch).
19932004 if (getNumIndices () == 0 && getVector ().getType () == getResult ().getType ())
19942005 return getVector ();
2006+ if (auto res = foldPoisonIndexInsertExtractOp (
2007+ getContext (), adaptor.getStaticPosition (), kPoisonIndex ))
2008+ return res;
19952009 if (succeeded (foldExtractOpFromExtractChain (*this )))
19962010 return getResult ();
19972011 if (auto res = ExtractFromInsertTransposeChainState (*this ).fold ())
@@ -2262,13 +2276,15 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
22622276// / Fold an insert or extract operation into an poison value when a poison index
22632277// / is found at any dimension of the static position.
22642278template <typename OpTy>
2265- LogicalResult foldPoisonIndexInsertExtractOp (OpTy op,
2266- PatternRewriter &rewriter) {
2267- if (!llvm::is_contained (op.getStaticPosition (), OpTy::kPoisonIndex ))
2268- return failure ();
2279+ LogicalResult
2280+ canonicalizePoisonIndexInsertExtractOp (OpTy op, PatternRewriter &rewriter) {
2281+ if (auto poisonAttr = foldPoisonIndexInsertExtractOp (
2282+ op.getContext (), op.getStaticPosition (), OpTy::kPoisonIndex )) {
2283+ rewriter.replaceOpWithNewOp <ub::PoisonOp>(op, op.getType (), poisonAttr);
2284+ return success ();
2285+ }
22692286
2270- rewriter.replaceOpWithNewOp <ub::PoisonOp>(op, op.getResult ().getType ());
2271- return success ();
2287+ return failure ();
22722288}
22732289
22742290} // namespace
@@ -2279,7 +2295,7 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
22792295 ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
22802296 results.add (foldExtractFromShapeCastToShapeCast);
22812297 results.add (foldExtractFromFromElements);
2282- results.add (foldPoisonIndexInsertExtractOp <ExtractOp>);
2298+ results.add (canonicalizePoisonIndexInsertExtractOp <ExtractOp>);
22832299}
22842300
22852301static void populateFromInt64AttrArray (ArrayAttr arrayAttr,
@@ -3044,7 +3060,7 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
30443060 MLIRContext *context) {
30453061 results.add <InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
30463062 InsertOpConstantFolder>(context);
3047- results.add (foldPoisonIndexInsertExtractOp <InsertOp>);
3063+ results.add (canonicalizePoisonIndexInsertExtractOp <InsertOp>);
30483064}
30493065
30503066OpFoldResult vector::InsertOp::fold (FoldAdaptor adaptor) {
@@ -3053,6 +3069,10 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
30533069 // (type mismatch).
30543070 if (getNumIndices () == 0 && getSourceType () == getType ())
30553071 return getSource ();
3072+ if (auto res = foldPoisonIndexInsertExtractOp (
3073+ getContext (), adaptor.getStaticPosition (), kPoisonIndex ))
3074+ return res;
3075+
30563076 return {};
30573077}
30583078
0 commit comments