Commit bc71c3b
[mlir][vector] Turn ExtractOpFromLoad into a canonicalization
This addresses a TODO and an earlier review comment from the original PR where the pattern was introduced:

* #134389 (comment)

The pattern is relatively straightforward and has not been updated since it landed, so it seems reasonable to promote it to a canonicalization.

Note: this change only moves the existing pattern into the canonicalization infrastructure; it does not add or remove any functionality.
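
For reference, the rewrite that `ExtractOpFromLoad` performs, adapted from the pattern's documentation comment in the diff below (the pattern only fires when the load has a single use):

  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
  %1 = vector.extract %0[1] : f32 from vector<4xf32>

becomes

  %c1 = arith.constant 1 : index
  %idx = arith.addi %arg1, %c1 overflow<nsw> : index
  %1 = memref.load %arg0[%idx] : memref<?xf32>

Because the rewrite can now create `memref.load` ops during canonicalization, the Vector dialect gains `memref::MemRefDialect` as a dependent dialect (hence the Vector.td and CMakeLists.txt changes below).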
Parent: fa511cd

7 files changed: +233 additions, -133 deletions

mlir/include/mlir/Dialect/Vector/IR/Vector.td

Lines changed: 2 additions & 1 deletion
@@ -23,7 +23,8 @@ def Vector_Dialect : Dialect {
   let hasConstantMaterializer = 1;
   let dependentDialects = [
     "arith::ArithDialect",
-    "ub::UBDialect"
+    "ub::UBDialect",
+    "memref::MemRefDialect"
   ];
 }


mlir/lib/Dialect/Vector/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIRVectorDialect
   LINK_LIBS PUBLIC
   MLIRAffineDialect
   MLIRArithDialect
+  MLIRMemRefDialect
   MLIRControlFlowInterfaces
   MLIRDataLayoutInterfaces
   MLIRDestinationStyleOpInterface

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 96 additions & 1 deletion
@@ -2226,6 +2226,99 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
   }
 };
 
+/// Check if the element type is suitable for vector.load/store sinking.
+/// The element type must be an index type or a byte-aligned integer or
+/// floating-point type.
+static bool isSupportedMemSinkElementType(Type type) {
+  if (isa<IndexType>(type))
+    return true;
+
+  return type.isIntOrFloat() && type.getIntOrFloatBitWidth() % 8 == 0;
+}
+
+/// Pattern to rewrite `vector.extract(vector.load)` -> `vector/memref.load`.
+/// Only index and byte-aligned integer and floating-point element types are
+/// supported for now.
+///
+/// Example:
+/// ```
+///   %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+///   %1 = vector.extract %0[1] : f32 from vector<4xf32>
+/// ```
+/// Gets converted to:
+/// ```
+///   %c1 = arith.constant 1 : index
+///   %0 = arith.addi %arg1, %c1 overflow<nsw> : index
+///   %1 = memref.load %arg0[%0] : memref<?xf32>
+/// ```
+///
+/// Note: this is considered beneficial only in single-use cases.
+class ExtractOpFromLoad final : public OpRewritePattern<ExtractOp> {
+public:
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loadOp = op.getSource().getDefiningOp<vector::LoadOp>();
+    if (!loadOp)
+      return rewriter.notifyMatchFailure(op, "expected a load op");
+
+    // Check for a single use so we won't duplicate load ops.
+    if (!loadOp->hasOneUse())
+      return rewriter.notifyMatchFailure(op, "expected single op use");
+
+    VectorType loadVecType = loadOp.getVectorType();
+    if (loadVecType.isScalable())
+      return rewriter.notifyMatchFailure(op,
+                                         "scalable vectors are not supported");
+
+    MemRefType memType = loadOp.getMemRefType();
+
+    // Non-byte-aligned types are tricky and may require special handling;
+    // ignore them for now.
+    if (!isSupportedMemSinkElementType(memType.getElementType()))
+      return rewriter.notifyMatchFailure(op, "unsupported element type");
+
+    int64_t rankOffset = memType.getRank() - loadVecType.getRank();
+    if (rankOffset < 0)
+      return rewriter.notifyMatchFailure(op, "unsupported ranks combination");
+
+    auto extractVecType = dyn_cast<VectorType>(op.getResult().getType());
+    int64_t finalRank = 0;
+    if (extractVecType)
+      finalRank = extractVecType.getRank();
+
+    SmallVector<Value> indices = loadOp.getIndices();
+    SmallVector<OpFoldResult> extractPos = op.getMixedPosition();
+
+    // There may be memory stores between the load and the extract op, so we
+    // need to make sure that the new load op is inserted at the same place as
+    // the original load op.
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(loadOp);
+    Location loc = loadOp.getLoc();
+    ArithIndexingBuilder idxBuilderf(rewriter, loc);
+    for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
+      OpFoldResult pos = extractPos[i - rankOffset];
+      if (isZeroInteger(pos))
+        continue;
+
+      Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);
+      indices[i] = idxBuilderf.add(indices[i], offset);
+    }
+
+    Value base = loadOp.getBase();
+    if (extractVecType) {
+      rewriter.replaceOpWithNewOp<vector::LoadOp>(op, extractVecType, base,
+                                                  indices);
+    } else {
+      rewriter.replaceOpWithNewOp<memref::LoadOp>(op, base, indices);
+    }
+    // We checked for a single use, so we can safely erase the load op.
+    rewriter.eraseOp(loadOp);
+    return success();
+  }
+};
+
 // Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask.
 class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
 public:

@@ -2363,7 +2456,9 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+  results
+      .add<ExtractOpFromBroadcast, ExtractOpFromCreateMask, ExtractOpFromLoad>(
+          context);
   results.add(foldExtractFromShapeCastToShapeCast);
   results.add(foldExtractFromFromElements);
 }
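
Since `ExtractOpFromLoad` is now registered via `ExtractOp::getCanonicalizationPatterns`, it runs as part of any `-canonicalize` pipeline (the new test file below exercises exactly that). For downstream code, a minimal sketch of applying these canonicalizations programmatically: `runExtractCanonicalizations` is a hypothetical helper, and the greedy driver entry point is spelled `applyPatternsGreedily` in current MLIR (`applyPatternsAndFoldGreedily` in older releases).

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical helper: collect the vector.extract canonicalization patterns
// (which now include ExtractOpFromLoad) and apply them greedily to a module.
static mlir::LogicalResult runExtractCanonicalizations(mlir::ModuleOp module) {
  mlir::MLIRContext *ctx = module.getContext();
  mlir::RewritePatternSet patterns(ctx);
  mlir::vector::ExtractOp::getCanonicalizationPatterns(patterns, ctx);
  return mlir::applyPatternsGreedily(module, std::move(patterns));
}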

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 1 addition & 2 deletions
@@ -2384,8 +2384,7 @@ void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
 void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
                                                     PatternBenefit benefit) {
   // TODO: Consider converting these patterns to canonicalizations.
-  patterns.add<ExtractOpFromLoad, StoreOpFromBroadcast>(patterns.getContext(),
-                                                        benefit);
+  patterns.add<StoreOpFromBroadcast>(patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateChainedVectorReductionFoldingPatterns(
Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file | FileCheck %s

// This file contains some tests of folding/canonicalizing vector.extract.

//-----------------------------------------------------------------------------
// [Pattern: ExtractOpFromLoad]
//-----------------------------------------------------------------------------

// CHECK-LABEL: @extract_load_scalar
// CHECK-SAME:    (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
func.func @extract_load_scalar(%arg0: memref<?xf32>, %arg1: index) -> f32 {
// CHECK:   %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>
// CHECK:   return %[[RES]] : f32
  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
  %1 = vector.extract %0[0] : f32 from vector<4xf32>
  return %1 : f32
}

// CHECK-LABEL: @extract_load_index
// CHECK-SAME:    (%[[ARG0:.*]]: memref<?xindex>, %[[ARG1:.*]]: index)
func.func @extract_load_index(%arg0: memref<?xindex>, %arg1: index) -> index {
// CHECK:   %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]]] : memref<?xindex>
// CHECK:   return %[[RES]] : index
  %0 = vector.load %arg0[%arg1] : memref<?xindex>, vector<4xindex>
  %1 = vector.extract %0[0] : index from vector<4xindex>
  return %1 : index
}

// CHECK-LABEL: @extract_load_scalar_non_zero_off
// CHECK-SAME:    (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
func.func @extract_load_scalar_non_zero_off(%arg0: memref<?xf32>, %arg1: index) -> f32 {
// CHECK:   %[[C1:.*]] = arith.constant 1 : index
// CHECK:   %[[OFF:.*]] = arith.addi %[[ARG1]], %[[C1]] overflow<nsw> : index
// CHECK:   %[[RES:.*]] = memref.load %[[ARG0]][%[[OFF]]] : memref<?xf32>
// CHECK:   return %[[RES]] : f32
  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
  %1 = vector.extract %0[1] : f32 from vector<4xf32>
  return %1 : f32
}

// CHECK-LABEL: @extract_load_scalar_dyn_off
// CHECK-SAME:    (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
func.func @extract_load_scalar_dyn_off(%arg0: memref<?xf32>, %arg1: index, %arg2: index) -> f32 {
// CHECK:   %[[OFF:.*]] = arith.addi %[[ARG1]], %[[ARG2]] overflow<nsw> : index
// CHECK:   %[[RES:.*]] = memref.load %[[ARG0]][%[[OFF]]] : memref<?xf32>
// CHECK:   return %[[RES]] : f32
  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
  %1 = vector.extract %0[%arg2] : f32 from vector<4xf32>
  return %1 : f32
}

// CHECK-LABEL: @extract_load_vec_non_zero_off
// CHECK-SAME:    (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
func.func @extract_load_vec_non_zero_off(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) -> vector<4xf32> {
// CHECK:   %[[C1:.*]] = arith.constant 1 : index
// CHECK:   %[[OFF:.*]] = arith.addi %[[ARG1]], %[[C1]] overflow<nsw> : index
// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[OFF]], %[[ARG2]]] : memref<?x?xf32>, vector<4xf32>
// CHECK:   return %[[RES]] : vector<4xf32>
  %0 = vector.load %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<2x4xf32>
  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
  return %1 : vector<4xf32>
}

// CHECK-LABEL: @extract_load_scalar_non_zero_off_2d_src_memref
// CHECK-SAME:    (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
func.func @extract_load_scalar_non_zero_off_2d_src_memref(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) -> f32 {
// CHECK:   %[[C1:.*]] = arith.constant 1 : index
// CHECK:   %[[OFF:.*]] = arith.addi %[[ARG2]], %[[C1]] overflow<nsw> : index
// CHECK:   %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[OFF]]] : memref<?x?xf32>
// CHECK:   return %[[RES]] : f32
  %0 = vector.load %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<4xf32>
  %1 = vector.extract %0[1] : f32 from vector<4xf32>
  return %1 : f32
}

// CHECK-LABEL: @extract_load_vec_high_rank
// CHECK-SAME:    (%[[ARG0:.*]]: memref<?x?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
func.func @extract_load_vec_high_rank(%arg0: memref<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index) -> vector<4xf32> {
// CHECK:   %[[C1:.*]] = arith.constant 1 : index
// CHECK:   %[[OFF:.*]] = arith.addi %[[ARG2]], %[[C1]] overflow<nsw> : index
// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[OFF]], %[[ARG3]]] : memref<?x?x?xf32>, vector<4xf32>
// CHECK:   return %[[RES]] : vector<4xf32>
  %0 = vector.load %arg0[%arg1, %arg2, %arg3] : memref<?x?x?xf32>, vector<2x4xf32>
  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
  return %1 : vector<4xf32>
}

// CHECK-LABEL: @negative_extract_load_scalar_from_memref_of_vec
// CHECK-SAME:    (%[[ARG0:.*]]: memref<?xvector<4xf32>>, %[[ARG1:.*]]: index)
func.func @negative_extract_load_scalar_from_memref_of_vec(%arg0: memref<?xvector<4xf32>>, %arg1: index) -> f32 {
// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xvector<4xf32>>, vector<4xf32>
// CHECK:   %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<4xf32>
// CHECK:   return %[[EXT]] : f32
  %0 = vector.load %arg0[%arg1] : memref<?xvector<4xf32>>, vector<4xf32>
  %1 = vector.extract %0[0] : f32 from vector<4xf32>
  return %1 : f32
}

// CHECK-LABEL: @negative_extract_load_scalar_from_memref_of_i1
// CHECK-SAME:    (%[[ARG0:.*]]: memref<?xi1>, %[[ARG1:.*]]: index)
func.func @negative_extract_load_scalar_from_memref_of_i1(%arg0: memref<?xi1>, %arg1: index) -> i1 {
// Sub-byte types are tricky; ignore them for now.
// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xi1>, vector<8xi1>
// CHECK:   %[[EXT:.*]] = vector.extract %[[RES]][0] : i1 from vector<8xi1>
// CHECK:   return %[[EXT]] : i1
  %0 = vector.load %arg0[%arg1] : memref<?xi1>, vector<8xi1>
  %1 = vector.extract %0[0] : i1 from vector<8xi1>
  return %1 : i1
}

// CHECK-LABEL: @negative_extract_load_no_single_use
// CHECK-SAME:    (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
func.func @negative_extract_load_no_single_use(%arg0: memref<?xf32>, %arg1: index) -> (f32, vector<4xf32>) {
// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<4xf32>
// CHECK:   %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<4xf32>
// CHECK:   return %[[EXT]], %[[RES]] : f32, vector<4xf32>
  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
  %1 = vector.extract %0[0] : f32 from vector<4xf32>
  return %1, %0 : f32, vector<4xf32>
}

// CHECK-LABEL: @negative_extract_load_scalable
// CHECK-SAME:    (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
func.func @negative_extract_load_scalable(%arg0: memref<?xf32>, %arg1: index) -> f32 {
// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<[1]xf32>
// CHECK:   %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<[1]xf32>
// CHECK:   return %[[EXT]] : f32
  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<[1]xf32>
  %1 = vector.extract %0[0] : f32 from vector<[1]xf32>
  return %1 : f32
}
