diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 93f89eda2da5a..2ec1b97f2f241 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -437,6 +437,9 @@ void VectorDialect::initialize() { Operation *VectorDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { + if (isa(value)) + return value.getDialect().materializeConstant(builder, value, type, loc); + return arith::ConstantOp::materialize(builder, value, type, loc); } @@ -2273,20 +2276,6 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp, return success(); } -/// Fold an insert or extract operation into an poison value when a poison index -/// is found at any dimension of the static position. -template -LogicalResult -canonicalizePoisonIndexInsertExtractOp(OpTy op, PatternRewriter &rewriter) { - if (auto poisonAttr = foldPoisonIndexInsertExtractOp( - op.getContext(), op.getStaticPosition(), OpTy::kPoisonIndex)) { - rewriter.replaceOpWithNewOp(op, op.getType(), poisonAttr); - return success(); - } - - return failure(); -} - } // namespace void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results, @@ -2295,7 +2284,6 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results, ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context); results.add(foldExtractFromShapeCastToShapeCast); results.add(foldExtractFromFromElements); - results.add(canonicalizePoisonIndexInsertExtractOp); } static void populateFromInt64AttrArray(ArrayAttr arrayAttr, @@ -3068,7 +3056,6 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add(context); - results.add(canonicalizePoisonIndexInsertExtractOp); } OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) { diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 7df6defc0f202..9a6337f14ace3 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1250,13 +1250,13 @@ func.func @extract_scalar_from_vec_1d_f32(%arg0: vector<16xf32>) -> f32 { // ----- -func.func @extract_poison_idx(%arg0: vector<16xf32>) -> f32 { +func.func @extract_scalar_from_vec_1d_f32_poison_idx(%arg0: vector<16xf32>) -> f32 { %0 = vector.extract %arg0[-1]: f32 from vector<16xf32> return %0 : f32 } -// CHECK-LABEL: @extract_poison_idx -// CHECK: %[[IDX:.*]] = llvm.mlir.constant(-1 : i64) : i64 -// CHECK: llvm.extractelement {{.*}}[%[[IDX]] : i64] : vector<16xf32> +// CHECK-LABEL: @extract_scalar_from_vec_1d_f32_poison_idx +// CHECK: %[[UB:.*]] = ub.poison : f32 +// CHECK: return %[[UB]] : f32 // ----- @@ -1335,6 +1335,16 @@ func.func @extract_vec_2d_from_vec_3d_f32(%arg0: vector<4x3x16xf32>) -> vector<3 // ----- +func.func @extract_vec_2d_from_vec_3d_f32_poison_idx(%arg0: vector<4x3x16xf32>) -> vector<3x16xf32> { + %0 = vector.extract %arg0[-1]: vector<3x16xf32> from vector<4x3x16xf32> + return %0 : vector<3x16xf32> +} +// CHECK-LABEL: @extract_vec_2d_from_vec_3d_f32_poison_idx +// CHECK: %[[UB:.*]] = ub.poison : vector<3x16xf32> +// CHECK: return %[[UB]] : vector<3x16xf32> + +// ----- + func.func @extract_vec_2d_from_vec_3d_f32_scalable(%arg0: vector<4x3x[16]xf32>) -> vector<3x[16]xf32> { %0 = vector.extract %arg0[0]: vector<3x[16]xf32> from vector<4x3x[16]xf32> return %0 : vector<3x[16]xf32>