Skip to content

Commit 2f57f48

Browse files
committed
[mlir][Vector][NFC] Move canonicalizers for DenseElementsAttr to folders
1 parent 544a161 commit 2f57f48

File tree

5 files changed

+62
-109
lines changed

5 files changed

+62
-109
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 46 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -2047,6 +2047,49 @@ static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) {
20472047
return {};
20482048
}
20492049

// Folds an extract whose source vector is a constant DenseElementsAttr into
// the attribute holding the extracted value(s).
//
// `extractOp` is the op being folded; `srcAttr` is the constant attribute of
// its source vector operand (may be null).
//
// Returns a null Attribute (no fold) when `srcAttr` is null or is not a
// DenseElementsAttr, and, for non-splat sources, when the source vector type
// is scalable or the op has dynamic positions. Otherwise returns the element
// attribute for a scalar result, or a DenseElementsAttr of the result vector
// type for a vector result.
2050+
static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp,
2051+
Attribute srcAttr) {
// dyn_cast_if_present tolerates a null `srcAttr`.
2052+
auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
2053+
if (!denseAttr) {
2054+
return {};
2055+
}
2056+
// Splat sources are handled before the scalable / dynamic-position checks
// below, so they fold regardless of the extract position: every element
// equals the splat value.
2057+
if (denseAttr.isSplat()) {
2058+
Attribute newAttr = denseAttr.getSplatValue<Attribute>();
// A vector result is rebuilt as a DenseElementsAttr of the result type from
// the splat value; a scalar result is the splat value itself.
2059+
if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
2060+
newAttr = DenseElementsAttr::get(vecDstType, newAttr);
2061+
return newAttr;
2062+
}
2063+
// Non-splat sources: bail out on scalable source vectors.
2064+
auto vecTy = llvm::cast<VectorType>(extractOp.getSourceVectorType());
2065+
if (vecTy.isScalable())
2066+
return {};
2067+
// Only static positions are used below, so dynamic positions are not folded.
2068+
if (extractOp.hasDynamicPosition()) {
2069+
return {};
2070+
}
2071+
2072+
// Calculate the linearized position of the continuous chunk of elements to
2073+
// extract.
// The position is padded with zeros up to the source rank, so the chunk
// starts at index 0 of every dimension not covered by the static position.
2074+
llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
2075+
copy(extractOp.getStaticPosition(), completePositions.begin());
2076+
int64_t elemBeginPosition =
2077+
linearize(completePositions, computeStrides(vecTy.getShape()));
2078+
auto denseValuesBegin =
2079+
denseAttr.value_begin<TypedAttr>() + elemBeginPosition;
2080+
2081+
TypedAttr newAttr;
// Vector result: take getNumElements() consecutive values starting at the
// linearized position. Scalar result: take the single value at that position.
2082+
if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
2083+
SmallVector<Attribute> elementValues(
2084+
denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
2085+
newAttr = DenseElementsAttr::get(resVecTy, elementValues);
2086+
} else {
2087+
newAttr = *denseValuesBegin;
2088+
}
2089+
2090+
return newAttr;
2091+
}
2092+
20502093
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
20512094
// Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
20522095
// Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
@@ -2058,6 +2101,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
20582101
return res;
20592102
if (auto res = foldPoisonSrcExtractOp(adaptor.getVector()))
20602103
return res;
2104+
if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getVector()))
2105+
return res;
20612106
if (succeeded(foldExtractOpFromExtractChain(*this)))
20622107
return getResult();
20632108
if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
@@ -2121,80 +2166,6 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
21212166
}
21222167
};
21232168

2124-
// Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp.
2125-
class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
2126-
public:
2127-
using OpRewritePattern::OpRewritePattern;
2128-
2129-
LogicalResult matchAndRewrite(ExtractOp extractOp,
2130-
PatternRewriter &rewriter) const override {
2131-
// Return if 'ExtractOp' operand is not defined by a splat vector
2132-
// ConstantOp.
2133-
Value sourceVector = extractOp.getVector();
2134-
Attribute vectorCst;
2135-
if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
2136-
return failure();
2137-
auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
2138-
if (!splat)
2139-
return failure();
2140-
TypedAttr newAttr = splat.getSplatValue<TypedAttr>();
2141-
if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
2142-
newAttr = DenseElementsAttr::get(vecDstType, newAttr);
2143-
rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
2144-
return success();
2145-
}
2146-
};
2147-
2148-
// Pattern to rewrite a ExtractOp(non-splat ConstantOp)[...] -> ConstantOp.
2149-
class ExtractOpNonSplatConstantFolder final
2150-
: public OpRewritePattern<ExtractOp> {
2151-
public:
2152-
using OpRewritePattern::OpRewritePattern;
2153-
2154-
LogicalResult matchAndRewrite(ExtractOp extractOp,
2155-
PatternRewriter &rewriter) const override {
2156-
// TODO: Canonicalization for dynamic position not implemented yet.
2157-
if (extractOp.hasDynamicPosition())
2158-
return failure();
2159-
2160-
// Return if 'ExtractOp' operand is not defined by a compatible vector
2161-
// ConstantOp.
2162-
Value sourceVector = extractOp.getVector();
2163-
Attribute vectorCst;
2164-
if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
2165-
return failure();
2166-
2167-
auto vecTy = llvm::cast<VectorType>(sourceVector.getType());
2168-
if (vecTy.isScalable())
2169-
return failure();
2170-
2171-
// The splat case is handled by `ExtractOpSplatConstantFolder`.
2172-
auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
2173-
if (!dense || dense.isSplat())
2174-
return failure();
2175-
2176-
// Calculate the linearized position of the continuous chunk of elements to
2177-
// extract.
2178-
llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
2179-
copy(extractOp.getStaticPosition(), completePositions.begin());
2180-
int64_t elemBeginPosition =
2181-
linearize(completePositions, computeStrides(vecTy.getShape()));
2182-
auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;
2183-
2184-
TypedAttr newAttr;
2185-
if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
2186-
SmallVector<Attribute> elementValues(
2187-
denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
2188-
newAttr = DenseElementsAttr::get(resVecTy, elementValues);
2189-
} else {
2190-
newAttr = *denseValuesBegin;
2191-
}
2192-
2193-
rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
2194-
return success();
2195-
}
2196-
};
2197-
21982169
// Pattern to rewrite an ExtractOp(CreateMask) -> CreateMask.
21992170
class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
22002171
public:
@@ -2332,8 +2303,7 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
23322303

23332304
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
23342305
MLIRContext *context) {
2335-
results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
2336-
ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
2306+
results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
23372307
results.add(foldExtractFromShapeCastToShapeCast);
23382308
results.add(foldExtractFromFromElements);
23392309
}

mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,8 @@ func.func @vectorize_nd_tensor_extract_transfer_read_basic(
3232

3333
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
3434
// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
35-
// CHECK-DAG: %[[CST_0:.+]] = arith.constant dense<0> : vector<1xindex>
36-
// CHECK-DAG: %[[CST_1:.+]] = arith.constant dense<[0, 1, 2]> : vector<3xindex>
3735

38-
// CHECK-DAG: %[[IDX1:.+]] = vector.extract %[[CST_0]][0] : index from vector<1xindex>
39-
// CHECK-DAG: %[[IDX2:.+]] = vector.extract %[[CST_0]][0] : index from vector<1xindex>
40-
// CHECK-DAG: %[[IDX3:.+]] = vector.extract %[[CST_1]][0] : index from vector<3xindex>
41-
42-
// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[IDX1]], %[[IDX2]], %[[IDX3]]], %[[CST]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
36+
// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]], %[[CST]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
4337
// CHECK: vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32>
4438

4539
// -----
@@ -175,16 +169,12 @@ func.func @vectorize_nd_tensor_extract_with_maxsi_contiguous(%arg0: tensor<80x16
175169
// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_with_maxsi_contiguous(
176170
// CHECK-SAME: %[[VAL_0:.*]]: tensor<80x16xf32>,
177171
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
178-
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
179-
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32
180-
181-
// CHECK-DAG: %[[CST_0:.+]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
182-
// CHECK-DAG: %[[CST_1:.+]] = arith.constant dense<16> : vector<4x1xindex>
183-
// CHECK-DAG: %[[IDX0:.+]] = vector.extract %[[CST_1]][0, 0] : index from vector<4x1xindex>
184-
// CHECK-DAG: %[[IDX1:.+]] = vector.extract %[[CST_0]][0] : index from vector<4xindex>
172+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
173+
// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
174+
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
185175

186-
// CHECK: %[[VAL_8:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[IDX0]], %[[IDX1]]], %[[VAL_5]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
187-
// CHECK: %[[VAL_9:.*]] = vector.transfer_write %[[VAL_8]], %[[VAL_1]]{{\[}}%[[VAL_4]], %[[VAL_4]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
176+
// CHECK: %[[VAL_8:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[C16]], %[[C0]]], %[[CST]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
177+
// CHECK: %[[VAL_9:.*]] = vector.transfer_write %[[VAL_8]], %[[VAL_1]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
188178
// CHECK: return %[[VAL_9]] : tensor<1x4xf32>
189179
// CHECK: }
190180

@@ -675,9 +665,7 @@ func.func @scalar_read_with_broadcast_from_column_tensor(%init: tensor<1x1x4xi32
675665
// CHECK-DAG: %[[PAD:.*]] = arith.constant 0 : i32
676666
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
677667
// CHECK-DAG: %[[SRC:.*]] = arith.constant dense<{{\[\[}}0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
678-
// CHECK-DAG: %[[IDX_VEC:.*]] = arith.constant dense<0> : vector<1xindex>
679-
// CHECK: %[[IDX_ELT:.*]] = vector.extract %[[IDX_VEC]][0] : index from vector<1xindex>
680-
// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{\[}}%[[IDX_ELT]], %[[C0]]], %[[PAD]] : tensor<15x1xi32>, vector<i32>
668+
// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] : tensor<15x1xi32>, vector<i32>
681669
// CHECK: %[[READ_BCAST:.*]] = vector.broadcast %[[READ]] : vector<i32> to vector<1x1x4xi32>
682670
// CHECK: %[[RES:.*]] = vector.transfer_write %[[READ_BCAST]], %[[INIT]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x4xi32>, tensor<1x1x4xi32>
683671

mlir/test/Dialect/Vector/linearize.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,12 +310,12 @@ func.func @test_vector_insert_scalable(%arg0: vector<2x8x[4]xf32>, %arg1: vector
310310
// -----
311311

312312
// ALL-LABEL: test_vector_extract_scalar
313-
func.func @test_vector_extract_scalar() {
313+
func.func @test_vector_extract_scalar(%idx : index) {
314314
%cst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32>
315315
// ALL-NOT: vector.shuffle
316316
// ALL: vector.extract
317317
// ALL-NOT: vector.shuffle
318-
%0 = vector.extract %cst[0] : i32 from vector<4xi32>
318+
%0 = vector.extract %cst[%idx] : i32 from vector<4xi32>
319319
return
320320
}
321321

mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,8 @@ func.func @transfer_read_2d_extract(%m: memref<?x?x?x?xf32>, %idx: index, %idx2:
101101

102102
// CHECK-LABEL: func @transfer_write_arith_constant(
103103
// CHECK-SAME: %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index
104-
// CHECK: %[[cst:.*]] = arith.constant dense<5.000000e+00> : vector<1x1xf32>
105-
// CHECK: %[[extract:.*]] = vector.extract %[[cst]][0, 0] : f32 from vector<1x1xf32>
106-
// CHECK: memref.store %[[extract]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
104+
// CHECK: %[[cst:.*]] = arith.constant 5.000000e+00 : f32
105+
// CHECK: memref.store %[[cst]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
107106
func.func @transfer_write_arith_constant(%m: memref<?x?x?xf32>, %idx: index) {
108107
%cst = arith.constant dense<5.000000e+00> : vector<1x1xf32>
109108
vector.transfer_write %cst, %m[%idx, %idx, %idx] : vector<1x1xf32>, memref<?x?x?xf32>

mlir/test/Dialect/Vector/vector-gather-lowering.mlir

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -242,33 +242,29 @@ func.func @strided_gather(%base : memref<100x3xf32>,
242242
// CHECK-SAME: %[[IDXS:.*]]: vector<4xindex>,
243243
// CHECK-SAME: %[[VAL_4:.*]]: index,
244244
// CHECK-SAME: %[[VAL_5:.*]]: index) -> vector<4xf32> {
245+
// CHECK: %[[TRUE:.*]] = arith.constant true
245246
// CHECK: %[[CST_3:.*]] = arith.constant dense<3> : vector<4xindex>
246-
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<4xi1>
247247

248248
// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[base]] {{\[\[}}0, 1]] : memref<100x3xf32> into memref<300xf32>
249249
// CHECK: %[[NEW_IDXS:.*]] = arith.muli %[[IDXS]], %[[CST_3]] : vector<4xindex>
250250

251-
// CHECK: %[[MASK_0:.*]] = vector.extract %[[MASK]][0] : i1 from vector<4xi1>
252251
// CHECK: %[[IDX_0:.*]] = vector.extract %[[NEW_IDXS]][0] : index from vector<4xindex>
253-
// CHECK: scf.if %[[MASK_0]] -> (vector<4xf32>)
252+
// CHECK: scf.if %[[TRUE]] -> (vector<4xf32>)
254253
// CHECK: %[[M_0:.*]] = vector.load %[[COLLAPSED]][%[[IDX_0]]] : memref<300xf32>, vector<1xf32>
255254
// CHECK: %[[V_0:.*]] = vector.extract %[[M_0]][0] : f32 from vector<1xf32>
256255

257-
// CHECK: %[[MASK_1:.*]] = vector.extract %[[MASK]][1] : i1 from vector<4xi1>
258256
// CHECK: %[[IDX_1:.*]] = vector.extract %[[NEW_IDXS]][1] : index from vector<4xindex>
259-
// CHECK: scf.if %[[MASK_1]] -> (vector<4xf32>)
257+
// CHECK: scf.if %[[TRUE]] -> (vector<4xf32>)
260258
// CHECK: %[[M_1:.*]] = vector.load %[[COLLAPSED]][%[[IDX_1]]] : memref<300xf32>, vector<1xf32>
261259
// CHECK: %[[V_1:.*]] = vector.extract %[[M_1]][0] : f32 from vector<1xf32>
262260

263-
// CHECK: %[[MASK_2:.*]] = vector.extract %[[MASK]][2] : i1 from vector<4xi1>
264261
// CHECK: %[[IDX_2:.*]] = vector.extract %[[NEW_IDXS]][2] : index from vector<4xindex>
265-
// CHECK: scf.if %[[MASK_2]] -> (vector<4xf32>)
262+
// CHECK: scf.if %[[TRUE]] -> (vector<4xf32>)
266263
// CHECK: %[[M_2:.*]] = vector.load %[[COLLAPSED]][%[[IDX_2]]] : memref<300xf32>, vector<1xf32>
267264
// CHECK: %[[V_2:.*]] = vector.extract %[[M_2]][0] : f32 from vector<1xf32>
268265

269-
// CHECK: %[[MASK_3:.*]] = vector.extract %[[MASK]][3] : i1 from vector<4xi1>
270266
// CHECK: %[[IDX_3:.*]] = vector.extract %[[NEW_IDXS]][3] : index from vector<4xindex>
271-
// CHECK: scf.if %[[MASK_3]] -> (vector<4xf32>)
267+
// CHECK: scf.if %[[TRUE]] -> (vector<4xf32>)
272268
// CHECK: %[[M_3:.*]] = vector.load %[[COLLAPSED]][%[[IDX_3]]] : memref<300xf32>, vector<1xf32>
273269
// CHECK: %[[V_3:.*]] = vector.extract %[[M_3]][0] : f32 from vector<1xf32>
274270

0 commit comments

Comments
 (0)