diff --git a/mlir/include/mlir/Dialect/Vector/IR/Vector.td b/mlir/include/mlir/Dialect/Vector/IR/Vector.td
index 5125ae7c13717..beb6bedb908e9 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/Vector.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/Vector.td
@@ -23,7 +23,8 @@ def Vector_Dialect : Dialect {
   let hasConstantMaterializer = 1;
   let dependentDialects = [
     "arith::ArithDialect",
-    "ub::UBDialect"
+    "ub::UBDialect",
+    "memref::MemRefDialect"
   ];
 }
 
diff --git a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
index 0248896e096a0..9cf4fedbbe978 100644
--- a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
@@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIRVectorDialect
   LINK_LIBS PUBLIC
   MLIRAffineDialect
   MLIRArithDialect
+  MLIRMemRefDialect
   MLIRControlFlowInterfaces
   MLIRDataLayoutInterfaces
   MLIRDestinationStyleOpInterface
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2789f63555524..eceeed9d03f4b 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2226,6 +2226,99 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
   }
 };
 
+/// Check if the element type is suitable for vector.load/store sinking.
+/// Element type must be index or byte-aligned integer or floating-point type.
+static bool isSupportedMemSinkElementType(Type type) {
+  if (isa<IndexType>(type))
+    return true;
+
+  return type.isIntOrFloat() && type.getIntOrFloatBitWidth() % 8 == 0;
+}
+
+/// Pattern to rewrite `vector.extract(vector.load)` -> `vector/memref.load`.
+/// Only index and byte-aligned integer and floating-point element types are
+/// supported for now.
+///
+/// Example:
+/// ```
+/// %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+/// %1 = vector.extract %0[1] : f32 from vector<4xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// %c1 = arith.constant 1 : index
+/// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
+/// %1 = memref.load %arg0[%0] : memref<?xf32>
+/// ```
+///
+/// Note: this is considered beneficial only in single-use cases.
+class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
+public:
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loadOp = op.getSource().getDefiningOp<vector::LoadOp>();
+    if (!loadOp)
+      return rewriter.notifyMatchFailure(op, "expected a load op");
+
+    // Checking for single use so we won't duplicate load ops.
+    if (!loadOp->hasOneUse())
+      return rewriter.notifyMatchFailure(op, "expected single op use");
+
+    VectorType loadVecType = loadOp.getVectorType();
+    if (loadVecType.isScalable())
+      return rewriter.notifyMatchFailure(op,
+                                         "scalable vectors are not supported");
+
+    MemRefType memType = loadOp.getMemRefType();
+
+    // Non-byte-aligned types are tricky and may require special handling,
+    // ignore them for now.
+    if (!isSupportedMemSinkElementType(memType.getElementType()))
+      return rewriter.notifyMatchFailure(op, "unsupported element type");
+
+    int64_t rankOffset = memType.getRank() - loadVecType.getRank();
+    if (rankOffset < 0)
+      return rewriter.notifyMatchFailure(op, "unsupported ranks combination");
+
+    auto extractVecType = dyn_cast<VectorType>(op.getResult().getType());
+    int64_t finalRank = 0;
+    if (extractVecType)
+      finalRank = extractVecType.getRank();
+
+    SmallVector<Value> indices = loadOp.getIndices();
+    SmallVector<OpFoldResult> extractPos = op.getMixedPosition();
+
+    // There may be memory stores between the load and the extract op, so we
+    // need to make sure that the new load op is inserted at the same place as
+    // the original load op.
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(loadOp);
+    Location loc = loadOp.getLoc();
+    ArithIndexingBuilder idxBuilderf(rewriter, loc);
+    for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
+      OpFoldResult pos = extractPos[i - rankOffset];
+      if (isZeroInteger(pos))
+        continue;
+
+      Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);
+      indices[i] = idxBuilderf.add(indices[i], offset);
+    }
+
+    Value base = loadOp.getBase();
+    if (extractVecType) {
+      rewriter.replaceOpWithNewOp<vector::LoadOp>(op, extractVecType, base,
+                                                  indices);
+    } else {
+      rewriter.replaceOpWithNewOp<memref::LoadOp>(op, base, indices);
+    }
+    // We checked for single use so we can safely erase the load op.
+    rewriter.eraseOp(loadOp);
+    return success();
+  }
+};
+
 // Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask.
 class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
 public:
@@ -2363,7 +2456,9 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+  results
+      .add<ExtractOpFromBroadcast, ExtractOpFromCreateMask, ExtractOpFromLoad>(
+          context);
   results.add(foldExtractFromShapeCastToShapeCast);
   results.add(foldExtractFromFromElements);
 }
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 726da1e9a3d14..2e1f8ff38dbf6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -2384,8 +2384,7 @@ void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
 void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
                                                     PatternBenefit benefit) {
   // TODO: Consider converting these patterns to canonicalizations.
-  patterns.add<ExtractOpFromLoad, StoreOpFromBroadcast>(patterns.getContext(),
-                                                        benefit);
+  patterns.add<StoreOpFromBroadcast>(patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateChainedVectorReductionFoldingPatterns(
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir
new file mode 100644
index 0000000000000..c140bcb3af8ad
--- /dev/null
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir
@@ -0,0 +1,131 @@
+// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file | FileCheck %s
+
+// This file contains some tests of folding/canonicalizing vector.extract.
+
+//-----------------------------------------------------------------------------
+// [Pattern: ExtractOpFromLoad]
+//-----------------------------------------------------------------------------
+
+// CHECK-LABEL: @extract_load_scalar
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @extract_load_scalar(%arg0: memref<?xf32>, %arg1: index) -> f32 {
+// CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>
+// CHECK: return %[[RES]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  %1 = vector.extract %0[0] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_index
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xindex>, %[[ARG1:.*]]: index)
+func.func @extract_load_index(%arg0: memref<?xindex>, %arg1: index) -> index {
+// CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]]] : memref<?xindex>
+// CHECK: return %[[RES]] : index
+  %0 = vector.load %arg0[%arg1] : memref<?xindex>, vector<4xindex>
+  %1 = vector.extract %0[0] : index from vector<4xindex>
+  return %1 : index
+}
+
+// CHECK-LABEL: @extract_load_scalar_non_zero_off
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @extract_load_scalar_non_zero_off(%arg0: memref<?xf32>, %arg1: index) -> f32 {
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[OFF:.*]] = arith.addi %[[ARG1]], %[[C1]] overflow<nsw> : index
+// CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[OFF]]] : memref<?xf32>
+// CHECK: return %[[RES]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_scalar_dyn_off
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @extract_load_scalar_dyn_off(%arg0: memref<?xf32>, %arg1: index, %arg2: index) -> f32 {
+// CHECK: %[[OFF:.*]] = arith.addi %[[ARG1]], %[[ARG2]] overflow<nsw> : index
+// CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[OFF]]] : memref<?xf32>
+// CHECK: return %[[RES]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  %1 = vector.extract %0[%arg2] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_vec_non_zero_off
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @extract_load_vec_non_zero_off(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) -> vector<4xf32> {
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[OFF:.*]] = arith.addi %[[ARG1]], %[[C1]] overflow<nsw> : index
+// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[OFF]], %[[ARG2]]] : memref<?x?xf32>, vector<4xf32>
+// CHECK: return %[[RES]] : vector<4xf32>
+  %0 = vector.load %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<2x4xf32>
+  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+  return %1 : vector<4xf32>
+}
+
+// CHECK-LABEL: @extract_load_scalar_non_zero_off_2d_src_memref
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @extract_load_scalar_non_zero_off_2d_src_memref(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) -> f32 {
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[OFF:.*]] = arith.addi %[[ARG2]], %[[C1]] overflow<nsw> : index
+// CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[OFF]]] : memref<?x?xf32>
+// CHECK: return %[[RES]] : f32
+  %0 = vector.load %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_vec_high_rank
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+func.func @extract_load_vec_high_rank(%arg0: memref<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index) -> vector<4xf32> {
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[OFF:.*]] = arith.addi %[[ARG2]], %[[C1]] overflow<nsw> : index
+// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[OFF]], %[[ARG3]]] : memref<?x?x?xf32>, vector<4xf32>
+// CHECK: return %[[RES]] : vector<4xf32>
+  %0 = vector.load %arg0[%arg1, %arg2, %arg3] : memref<?x?x?xf32>, vector<2x4xf32>
+  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+  return %1 : vector<4xf32>
+}
+
+// CHECK-LABEL: @negative_extract_load_scalar_from_memref_of_vec
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xvector<4xf32>>, %[[ARG1:.*]]: index)
+func.func @negative_extract_load_scalar_from_memref_of_vec(%arg0: memref<?xvector<4xf32>>, %arg1: index) -> f32 {
+// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xvector<4xf32>>, vector<4xf32>
+// CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<4xf32>
+// CHECK: return %[[EXT]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xvector<4xf32>>, vector<4xf32>
+  %1 = vector.extract %0[0] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @negative_extract_load_scalar_from_memref_of_i1
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xi1>, %[[ARG1:.*]]: index)
+func.func @negative_extract_load_scalar_from_memref_of_i1(%arg0: memref<?xi1>, %arg1: index) -> i1 {
+// Subbyte types are tricky, ignore them for now.
+// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xi1>, vector<8xi1>
+// CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : i1 from vector<8xi1>
+// CHECK: return %[[EXT]] : i1
+  %0 = vector.load %arg0[%arg1] : memref<?xi1>, vector<8xi1>
+  %1 = vector.extract %0[0] : i1 from vector<8xi1>
+  return %1 : i1
+}
+
+// CHECK-LABEL: @negative_extract_load_no_single_use
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @negative_extract_load_no_single_use(%arg0: memref<?xf32>, %arg1: index) -> (f32, vector<4xf32>) {
+// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<4xf32>
+// CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<4xf32>
+// CHECK: return %[[EXT]], %[[RES]] : f32, vector<4xf32>
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  %1 = vector.extract %0[0] : f32 from vector<4xf32>
+  return %1, %0 : f32, vector<4xf32>
+}
+
+// CHECK-LABEL: @negative_extract_load_scalable
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @negative_extract_load_scalable(%arg0: memref<?xf32>, %arg1: index) -> f32 {
+// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<[1]xf32>
+// CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<[1]xf32>
+// CHECK: return %[[EXT]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<[1]xf32>
+  %1 = vector.extract %0[0] : f32 from vector<[1]xf32>
+  return %1 : f32
+}
diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir
index beaba52af1841..5d6ea5147fa73 100644
--- a/mlir/test/Dialect/Vector/vector-sink.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink.mlir
@@ -11,6 +11,7 @@
 // CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
 // CHECK: return %[[BCAST]] : vector<1x4xindex>
 
+
 func.func @broadcast_scalar_with_bcast(%arg1: index, %arg2: index) -> vector<1x4xindex> {
   %0 = vector.broadcast %arg1 : index to vector<1x4xindex>
   %1 = vector.broadcast %arg2 : index to vector<1x4xindex>
@@ -651,133 +652,6 @@ func.func @negative_extract_dynamic_pos(%arg0: vector<4xf32>, %arg1 : vector<4xf
   return %2 : f32
 }
 
-//-----------------------------------------------------------------------------
-// [Pattern: ExtractOpFromLoad]
-//-----------------------------------------------------------------------------
-
-// CHECK-LABEL: @extract_load_scalar
-// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
-func.func @extract_load_scalar(%arg0: memref<?xf32>, %arg1: index) -> f32 {
-// CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>
-// CHECK: return %[[RES]] : f32
-  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
-  %1 = vector.extract %0[0] : f32 from vector<4xf32>
-  return %1 : f32
-}
-
-// CHECK-LABEL: @extract_load_index
-// CHECK-SAME: (%[[ARG0:.*]]: memref<?xindex>, %[[ARG1:.*]]: index)
-func.func @extract_load_index(%arg0: memref<?xindex>, %arg1: index) -> index {
-// CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]]] : memref<?xindex>
-// CHECK: return %[[RES]] : index
-  %0 = vector.load %arg0[%arg1] : memref<?xindex>, vector<4xindex>
-  %1 = vector.extract %0[0] : index from vector<4xindex>
-  return %1 : index
-}
-
-// CHECK-LABEL: @extract_load_scalar_non_zero_off
-// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
-func.func @extract_load_scalar_non_zero_off(%arg0: memref<?xf32>, %arg1: index) -> f32 {
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[OFF:.*]] = arith.addi %[[ARG1]], %[[C1]] overflow<nsw> : index
-// CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[OFF]]] : memref<?xf32>
-// CHECK: return %[[RES]] : f32
-  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
-  %1 = vector.extract %0[1] : f32 from vector<4xf32>
-  return %1 : f32
-}
-
-// CHECK-LABEL: @extract_load_scalar_dyn_off
-// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
-func.func @extract_load_scalar_dyn_off(%arg0: memref<?xf32>, %arg1: index, %arg2: index) -> f32 {
-// CHECK: %[[OFF:.*]] = arith.addi %[[ARG1]], %[[ARG2]] overflow<nsw> : index
-// CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[OFF]]] : memref<?xf32>
-// CHECK: return %[[RES]] : f32
-  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
-  %1 = vector.extract %0[%arg2] : f32 from vector<4xf32>
-  return %1 : f32
-}
-
-// CHECK-LABEL: @extract_load_vec_non_zero_off
-// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
-func.func @extract_load_vec_non_zero_off(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) -> vector<4xf32> {
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[OFF:.*]] = arith.addi %[[ARG1]], %[[C1]] overflow<nsw> : index
-// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[OFF]], %[[ARG2]]] : memref<?x?xf32>, vector<4xf32>
-// CHECK: return %[[RES]] : vector<4xf32>
-  %0 = vector.load %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<2x4xf32>
-  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
-  return %1 : vector<4xf32>
-}
-
-// CHECK-LABEL: @extract_load_scalar_non_zero_off_2d_src_memref
-// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
-func.func @extract_load_scalar_non_zero_off_2d_src_memref(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) -> f32 {
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[OFF:.*]] = arith.addi %[[ARG2]], %[[C1]] overflow<nsw> : index
-// CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[OFF]]] : memref<?x?xf32>
-// CHECK: return %[[RES]] : f32
-  %0 = vector.load %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<4xf32>
-  %1 = vector.extract %0[1] : f32 from vector<4xf32>
-  return %1 : f32
-}
-
-// CHECK-LABEL: @extract_load_vec_high_rank
-// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
-func.func @extract_load_vec_high_rank(%arg0: memref<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index) -> vector<4xf32> {
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[OFF:.*]] = arith.addi %[[ARG2]], %[[C1]] overflow<nsw> : index
-// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[OFF]], %[[ARG3]]] : memref<?x?x?xf32>, vector<4xf32>
-// CHECK: return %[[RES]] : vector<4xf32>
-  %0 = vector.load %arg0[%arg1, %arg2, %arg3] : memref<?x?x?xf32>, vector<2x4xf32>
-  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
-  return %1 : vector<4xf32>
-}
-
-// CHECK-LABEL: @negative_extract_load_scalar_from_memref_of_vec
-// CHECK-SAME: (%[[ARG0:.*]]: memref<?xvector<4xf32>>, %[[ARG1:.*]]: index)
-func.func @negative_extract_load_scalar_from_memref_of_vec(%arg0: memref<?xvector<4xf32>>, %arg1: index) -> f32 {
-// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xvector<4xf32>>, vector<4xf32>
-// CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<4xf32>
-// CHECK: return %[[EXT]] : f32
-  %0 = vector.load %arg0[%arg1] : memref<?xvector<4xf32>>, vector<4xf32>
-  %1 = vector.extract %0[0] : f32 from vector<4xf32>
-  return %1 : f32
-}
-
-// CHECK-LABEL: @negative_extract_load_scalar_from_memref_of_i1
-// CHECK-SAME: (%[[ARG0:.*]]: memref<?xi1>, %[[ARG1:.*]]: index)
-func.func @negative_extract_load_scalar_from_memref_of_i1(%arg0: memref<?xi1>, %arg1: index) -> i1 {
-// Subbyte types are tricky, ignore them for now.
-// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xi1>, vector<8xi1>
-// CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : i1 from vector<8xi1>
-// CHECK: return %[[EXT]] : i1
-  %0 = vector.load %arg0[%arg1] : memref<?xi1>, vector<8xi1>
-  %1 = vector.extract %0[0] : i1 from vector<8xi1>
-  return %1 : i1
-}
-
-// CHECK-LABEL: @negative_extract_load_no_single_use
-// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
-func.func @negative_extract_load_no_single_use(%arg0: memref<?xf32>, %arg1: index) -> (f32, vector<4xf32>) {
-// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<4xf32>
-// CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<4xf32>
-// CHECK: return %[[EXT]], %[[RES]] : f32, vector<4xf32>
-  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
-  %1 = vector.extract %0[0] : f32 from vector<4xf32>
-  return %1, %0 : f32, vector<4xf32>
-}
-
-// CHECK-LABEL: @negative_extract_load_scalable
-// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
-func.func @negative_extract_load_scalable(%arg0: memref<?xf32>, %arg1: index) -> f32 {
-// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<[1]xf32>
-// CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<[1]xf32>
-// CHECK: return %[[EXT]] : f32
-  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<[1]xf32>
-  %1 = vector.extract %0[0] : f32 from vector<[1]xf32>
-  return %1 : f32
-}
 
 //-----------------------------------------------------------------------------
 // [Pattern: StoreOpFromBroadcast]
diff --git a/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
index 0bb7d7d3d8b1b..cd8fdb67fab89 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
@@ -168,8 +168,7 @@ func.func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2
 // CHECK: %4 = vector.extract %arg1[0, 0]
 // CHECK: %5 = arith.addi %4, %c1
 // CHECK: %6 = scf.if %3 -> (vector<3xf32>) {
-// CHECK: %{{.*}} = vector.load %arg0[%c0, %5] : memref<?x?xf32>, vector<1xf32>
-// CHECK: %{{.*}} = vector.extract {{.*}}[0] : f32
+// CHECK: %{{.*}} = memref.load %arg0[%c0, %5] : memref<?x?xf32>
 // CHECK: %{{.*}} = vector.insert {{.*}}, %2 [0] : f32 into vector<3xf32>
 // CHECK: scf.yield {{.*}} : vector<3xf32>
 // CHECK: } else {
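A minimal end-to-end illustration of the canonicalization being moved here (a sketch assuming `mlir-opt --canonicalize`; the function and value names are placeholders, not taken from the patch):

```mlir
// Input: a single-use vector.load feeding vector.extract.
func.func @example(%mem: memref<?xf32>, %i: index) -> f32 {
  %v = vector.load %mem[%i] : memref<?xf32>, vector<4xf32>
  %e = vector.extract %v[2] : f32 from vector<4xf32>
  return %e : f32
}
```

Because the load has a single use, a byte-aligned element type, and a fixed-length vector, ExtractOpFromLoad now fires during canonicalization and rewrites the pair into a scalar load at the adjusted offset, mirroring the tests above:

```mlir
func.func @example(%mem: memref<?xf32>, %i: index) -> f32 {
  // Offset folded into the index computation; nsw is safe since the
  // original in-bounds vector access implies no overflow.
  %c2 = arith.constant 2 : index
  %off = arith.addi %i, %c2 overflow<nsw> : index
  %r = memref.load %mem[%off] : memref<?xf32>
  return %r : f32
}
```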