3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/Vector/IR/Vector.td
@@ -23,7 +23,8 @@ def Vector_Dialect : Dialect {
let hasConstantMaterializer = 1;
let dependentDialects = [
"arith::ArithDialect",
"ub::UBDialect"
"ub::UBDialect",
"memref::MemRefDialect"
];
}

1 change: 1 addition & 0 deletions mlir/lib/Dialect/Vector/IR/CMakeLists.txt
@@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIRVectorDialect
LINK_LIBS PUBLIC
MLIRAffineDialect
MLIRArithDialect
MLIRMemRefDialect
MLIRControlFlowInterfaces
MLIRDataLayoutInterfaces
MLIRDestinationStyleOpInterface
97 changes: 96 additions & 1 deletion mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2226,6 +2226,99 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
}
};

/// Check if the element type is suitable for vector.load/store sinking.
/// Element type must be index or byte-aligned integer or floating-point type.
static bool isSupportedMemSinkElementType(Type type) {
if (isa<IndexType>(type))
return true;

return type.isIntOrFloat() && type.getIntOrFloatBitWidth() % 8 == 0;
}

/// Pattern to rewrite `vector.extract(vector.load) -> vector/memref.load`.
/// Only index and byte-aligned integer and floating-point element types are
/// supported for now.
///
/// Example:
/// ```
/// %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
/// vector.extract %0[1] : f32 from vector<4xf32>
/// ```
/// Gets converted to:
/// ```
/// %c1 = arith.constant 1 : index
/// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
/// %1 = memref.load %arg0[%0] : memref<?xf32>
/// ```
///
/// Note: this is considered beneficial only in single-use cases.
class ExtractOpFromLoad final : public OpRewritePattern<ExtractOp> {
public:
using Base::Base;

LogicalResult matchAndRewrite(vector::ExtractOp op,
PatternRewriter &rewriter) const override {
auto loadOp = op.getSource().getDefiningOp<vector::LoadOp>();
if (!loadOp)
return rewriter.notifyMatchFailure(op, "expected a load op");

// Checking for single use so we won't duplicate load ops.
if (!loadOp->hasOneUse())
return rewriter.notifyMatchFailure(op, "expected single op use");

VectorType loadVecType = loadOp.getVectorType();
if (loadVecType.isScalable())
return rewriter.notifyMatchFailure(op,
"scalable vectors are not supported");

MemRefType memType = loadOp.getMemRefType();

// Non-byte-aligned types are tricky and may require special handling,
// ignore them for now.
if (!isSupportedMemSinkElementType(memType.getElementType()))
return rewriter.notifyMatchFailure(op, "unsupported element type");

int64_t rankOffset = memType.getRank() - loadVecType.getRank();
if (rankOffset < 0)
return rewriter.notifyMatchFailure(op, "unsupported ranks combination");

auto extractVecType = dyn_cast<VectorType>(op.getResult().getType());
int64_t finalRank = 0;
if (extractVecType)
finalRank = extractVecType.getRank();

SmallVector<Value> indices = loadOp.getIndices();
SmallVector<OpFoldResult> extractPos = op.getMixedPosition();

// There may be memory stores between the load and the extract op, so we
// need to make sure that the new load op is inserted at the same place as
// the original load op.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(loadOp);
Location loc = loadOp.getLoc();
ArithIndexingBuilder idxBuilderf(rewriter, loc);
for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
OpFoldResult pos = extractPos[i - rankOffset];
if (isZeroInteger(pos))
continue;

Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);
indices[i] = idxBuilderf.add(indices[i], offset);
}

Value base = loadOp.getBase();
if (extractVecType) {
rewriter.replaceOpWithNewOp<vector::LoadOp>(op, extractVecType, base,
indices);
} else {
rewriter.replaceOpWithNewOp<memref::LoadOp>(op, base, indices);
}
// We checked for single use so we can safely erase the load op.
rewriter.eraseOp(loadOp);
return success();
}
};

// Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask.
class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
public:
@@ -2363,7 +2456,9 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,

void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
results
.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask, ExtractOpFromLoad>(
context);
results.add(foldExtractFromShapeCastToShapeCast);
results.add(foldExtractFromFromElements);
}
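A note on the insertion point handled above: because stores may sit between the `vector.load` and the `vector.extract`, the replacement load is created at the original load's position rather than at the extract. A minimal sketch of why this matters, with hypothetical IR (value names `%m`, `%i`, `%f` are assumed, not taken from the patch):

```mlir
// Before the rewrite: a store separates the load from the extract.
%v = vector.load %m[%i] : memref<?xf32>, vector<4xf32>
memref.store %f, %m[%i] : memref<?xf32>
%e = vector.extract %v[0] : f32 from vector<4xf32>

// After the rewrite: the scalar load is materialized where the
// vector.load was, so it still reads the value from before the store.
%e = memref.load %m[%i] : memref<?xf32>
memref.store %f, %m[%i] : memref<?xf32>
```

The index arithmetic in the loop plays a similar supporting role: the extract positions are folded into the memref indices starting at `rankOffset`, as exercised by @extract_load_vec_high_rank in the new test file below.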
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -2384,8 +2384,7 @@ void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
PatternBenefit benefit) {
// TODO: Consider converting these patterns to canonicalizations.
patterns.add<ExtractOpFromLoad, StoreOpFromBroadcast>(patterns.getContext(),
benefit);
patterns.add<StoreOpFromBroadcast>(patterns.getContext(), benefit);
}

void mlir::vector::populateChainedVectorReductionFoldingPatterns(
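With `ExtractOpFromLoad` now registered in `ExtractOp::getCanonicalizationPatterns`, the rewrite fires under plain `-canonicalize` and no longer requires `populateSinkVectorMemOpsPatterns`. A before/after sketch under that assumption (function and value names are hypothetical):

```mlir
// Input:
func.func @fold_extract_of_load(%m: memref<?xf32>, %i: index) -> f32 {
  %v = vector.load %m[%i] : memref<?xf32>, vector<4xf32>
  %e = vector.extract %v[0] : f32 from vector<4xf32>
  return %e : f32
}
// Expected after `mlir-opt -canonicalize`: the load/extract pair
// folds to a single `memref.load %m[%i] : memref<?xf32>`.
```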
131 changes: 131 additions & 0 deletions mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir
@@ -0,0 +1,131 @@
// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file | FileCheck %s

// This file contains some tests of folding/canonicalizing vector.extract

//-----------------------------------------------------------------------------
// [Pattern: ExtractOpFromLoad]
//-----------------------------------------------------------------------------

// CHECK-LABEL: @extract_load_scalar
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
func.func @extract_load_scalar(%arg0: memref<?xf32>, %arg1: index) -> f32 {
// CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>
// CHECK: return %[[RES]] : f32
%0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
%1 = vector.extract %0[0] : f32 from vector<4xf32>
return %1 : f32
}

// CHECK-LABEL: @extract_load_index
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xindex>, %[[ARG1:.*]]: index)
func.func @extract_load_index(%arg0: memref<?xindex>, %arg1: index) -> index {
// CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]]] : memref<?xindex>
// CHECK: return %[[RES]] : index
%0 = vector.load %arg0[%arg1] : memref<?xindex>, vector<4xindex>
%1 = vector.extract %0[0] : index from vector<4xindex>
return %1 : index
}

// CHECK-LABEL: @extract_load_scalar_non_zero_off
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
func.func @extract_load_scalar_non_zero_off(%arg0: memref<?xf32>, %arg1: index) -> f32 {
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[OFF:.*]] = arith.addi %[[ARG1]], %[[C1]] overflow<nsw> : index
// CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[OFF]]] : memref<?xf32>
// CHECK: return %[[RES]] : f32
%0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
%1 = vector.extract %0[1] : f32 from vector<4xf32>
return %1 : f32
}

// CHECK-LABEL: @extract_load_scalar_dyn_off
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
func.func @extract_load_scalar_dyn_off(%arg0: memref<?xf32>, %arg1: index, %arg2: index) -> f32 {
// CHECK: %[[OFF:.*]] = arith.addi %[[ARG1]], %[[ARG2]] overflow<nsw> : index
// CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[OFF]]] : memref<?xf32>
// CHECK: return %[[RES]] : f32
%0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
%1 = vector.extract %0[%arg2] : f32 from vector<4xf32>
return %1 : f32
}

// CHECK-LABEL: @extract_load_vec_non_zero_off
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
func.func @extract_load_vec_non_zero_off(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) -> vector<4xf32> {
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[OFF:.*]] = arith.addi %[[ARG1]], %[[C1]] overflow<nsw> : index
// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[OFF]], %[[ARG2]]] : memref<?x?xf32>, vector<4xf32>
// CHECK: return %[[RES]] : vector<4xf32>
%0 = vector.load %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<2x4xf32>
%1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
return %1 : vector<4xf32>
}

// CHECK-LABEL: @extract_load_scalar_non_zero_off_2d_src_memref
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
func.func @extract_load_scalar_non_zero_off_2d_src_memref(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) -> f32 {
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[OFF:.*]] = arith.addi %[[ARG2]], %[[C1]] overflow<nsw> : index
// CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[OFF]]] : memref<?x?xf32>
// CHECK: return %[[RES]] : f32
%0 = vector.load %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<4xf32>
%1 = vector.extract %0[1] : f32 from vector<4xf32>
return %1 : f32
}

// CHECK-LABEL: @extract_load_vec_high_rank
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
func.func @extract_load_vec_high_rank(%arg0: memref<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index) -> vector<4xf32> {
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[OFF:.*]] = arith.addi %[[ARG2]], %[[C1]] overflow<nsw> : index
// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[OFF]], %[[ARG3]]] : memref<?x?x?xf32>, vector<4xf32>
// CHECK: return %[[RES]] : vector<4xf32>
%0 = vector.load %arg0[%arg1, %arg2, %arg3] : memref<?x?x?xf32>, vector<2x4xf32>
%1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
return %1 : vector<4xf32>
}

// CHECK-LABEL: @negative_extract_load_scalar_from_memref_of_vec
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xvector<4xf32>>, %[[ARG1:.*]]: index)
func.func @negative_extract_load_scalar_from_memref_of_vec(%arg0: memref<?xvector<4xf32>>, %arg1: index) -> f32 {
// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xvector<4xf32>>, vector<4xf32>
// CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<4xf32>
// CHECK: return %[[EXT]] : f32
%0 = vector.load %arg0[%arg1] : memref<?xvector<4xf32>>, vector<4xf32>
%1 = vector.extract %0[0] : f32 from vector<4xf32>
return %1 : f32
}

// CHECK-LABEL: @negative_extract_load_scalar_from_memref_of_i1
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xi1>, %[[ARG1:.*]]: index)
func.func @negative_extract_load_scalar_from_memref_of_i1(%arg0: memref<?xi1>, %arg1: index) -> i1 {
// Subbyte types are tricky, ignore them for now.
// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xi1>, vector<8xi1>
// CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : i1 from vector<8xi1>
// CHECK: return %[[EXT]] : i1
%0 = vector.load %arg0[%arg1] : memref<?xi1>, vector<8xi1>
%1 = vector.extract %0[0] : i1 from vector<8xi1>
return %1 : i1
}

// CHECK-LABEL: @negative_extract_load_no_single_use
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
func.func @negative_extract_load_no_single_use(%arg0: memref<?xf32>, %arg1: index) -> (f32, vector<4xf32>) {
// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<4xf32>
// CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<4xf32>
// CHECK: return %[[EXT]], %[[RES]] : f32, vector<4xf32>
%0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
%1 = vector.extract %0[0] : f32 from vector<4xf32>
return %1, %0 : f32, vector<4xf32>
}

// CHECK-LABEL: @negative_extract_load_scalable
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
func.func @negative_extract_load_scalable(%arg0: memref<?xf32>, %arg1: index) -> f32 {
// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<[1]xf32>
// CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<[1]xf32>
// CHECK: return %[[EXT]] : f32
%0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<[1]xf32>
%1 = vector.extract %0[0] : f32 from vector<[1]xf32>
return %1 : f32
}