Skip to content

Commit 4b77e6b

Browse files
update function name and add canonicalize test and regression test.
1 parent 2d77859 commit 4b77e6b

File tree

4 files changed

+94
-58
lines changed

4 files changed

+94
-58
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1979,33 +1979,27 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
19791979

19801980
// If the dynamic operands of `extractOp` or `insertOp` is result of
19811981
// `constantOp`, then fold it.
1982-
template <typename T>
1983-
static LogicalResult foldConstantOp(T op, SmallVectorImpl<Value> &operands) {
1982+
template <typename OpType, typename AdaptorType>
1983+
static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
1984+
SmallVectorImpl<Value> &operands) {
19841985
auto staticPosition = op.getStaticPosition().vec();
19851986
OperandRange dynamicPosition = op.getDynamicPosition();
1986-
1987+
ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();
19871988
// If the dynamic operands is empty, it is returned directly.
19881989
if (!dynamicPosition.size())
1989-
return failure();
1990+
return {};
19901991
unsigned index = 0;
19911992

1992-
// `opChange` is a flog. If it is true, it means to update `op` in place.
1993+
// `opChange` is a flag. If it is true, it means to update `op` in place.
19931994
bool opChange = false;
19941995
for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
19951996
if (!ShapedType::isDynamic(staticPosition[i]))
19961997
continue;
1998+
Attribute positionAttr = dynamicPositionAttr[index];
19971999
Value position = dynamicPosition[index++];
1998-
1999-
// If it is a block parameter, proceed to the next iteration.
2000-
if (!position.getDefiningOp()) {
2001-
operands.push_back(position);
2002-
continue;
2003-
}
2004-
2005-
APInt pos;
2006-
if (matchPattern(position, m_ConstantInt(&pos))) {
2000+
if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
2001+
staticPosition[i] = attr.getInt();
20072002
opChange = true;
2008-
staticPosition[i] = pos.getSExtValue();
20092003
continue;
20102004
}
20112005
operands.push_back(position);
@@ -2014,12 +2008,12 @@ static LogicalResult foldConstantOp(T op, SmallVectorImpl<Value> &operands) {
20142008
if (opChange) {
20152009
op.setStaticPosition(staticPosition);
20162010
op.getOperation()->setOperands(operands);
2017-
return success();
2011+
return op.getResult();
20182012
}
2019-
return failure();
2013+
return {};
20202014
}
20212015

2022-
OpFoldResult ExtractOp::fold(FoldAdaptor) {
2016+
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
20232017
// Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
20242018
// Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
20252019
// mismatch).
@@ -2042,8 +2036,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor) {
20422036
if (auto val = foldScalarExtractFromFromElements(*this))
20432037
return val;
20442038
SmallVector<Value> operands = {getVector()};
2045-
if (succeeded(foldConstantOp(*this, operands)))
2046-
return getResult();
2039+
if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
2040+
return val;
20472041
return OpFoldResult();
20482042
}
20492043

@@ -3074,8 +3068,8 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
30743068
if (getNumIndices() == 0 && getSourceType() == getType())
30753069
return getSource();
30763070
SmallVector<Value> operands = {getSource(), getDest()};
3077-
if (succeeded(foldConstantOp(*this, operands)))
3078-
return getResult();
3071+
if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
3072+
return val;
30793073
return {};
30803074
}
30813075

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -4115,39 +4115,3 @@ func.func @step_scalable() -> vector<[4]xindex> {
41154115
%0 = vector.step : vector<[4]xindex>
41164116
return %0 : vector<[4]xindex>
41174117
}
4118-
4119-
// -----
4120-
4121-
// CHECK-LABEL: @extract_arith_constnt
4122-
func.func @extract_arith_constnt() -> i32 {
4123-
%v = arith.constant dense<0> : vector<32x1xi32>
4124-
%c_0 = arith.constant 0 : index
4125-
%elem = vector.extract %v[%c_0, %c_0] : i32 from vector<32x1xi32>
4126-
return %elem : i32
4127-
}
4128-
4129-
// CHECK: %[[VAL_0:.*]] = arith.constant dense<0> : vector<32x1xi32>
4130-
// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
4131-
// CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
4132-
// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(0 : i64) : i64
4133-
// CHECK: %{{.*}} = llvm.extractelement %[[VAL_2]]{{\[}}%[[VAL_3]] : i64] : vector<1xi32>
4134-
4135-
// -----
4136-
4137-
// CHECK-LABEL: @insert_arith_constnt()
4138-
4139-
func.func @insert_arith_constnt() -> vector<32x1xi32> {
4140-
%v = arith.constant dense<0> : vector<32x1xi32>
4141-
%c_0 = arith.constant 0 : index
4142-
%c_1 = arith.constant 1 : i32
4143-
%v_1 = vector.insert %c_1, %v[%c_0, %c_0] : i32 into vector<32x1xi32>
4144-
return %v_1 : vector<32x1xi32>
4145-
}
4146-
4147-
// CHECK: %[[VAL_0:.*]] = arith.constant dense<0> : vector<32x1xi32>
4148-
// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
4149-
// CHECK: %[[VAL_2:.*]] = arith.constant 1 : i32
4150-
// CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
4151-
// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i64) : i64
4152-
// CHECK: %[[VAL_5:.*]] = llvm.insertelement %[[VAL_2]], %[[VAL_3]]{{\[}}%[[VAL_4]] : i64] : vector<1xi32>
4153-
// CHECK: %{{.*}} = llvm.insertvalue %[[VAL_5]], %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2979,3 +2979,47 @@ func.func @contiguous_scatter_step(%base: memref<?xf32>,
29792979
memref<?xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32>
29802980
return
29812981
}
2982+
2983+
// -----
2984+
2985+
// CHECK-LABEL: func @extract_arith_constnt
2986+
2987+
func.func @extract_arith_constnt() -> i32 {
2988+
%c1_i32 = arith.constant 1 : i32
2989+
return %c1_i32 : i32
2990+
}
2991+
2992+
// CHECK: %[[VAL_0:.*]] = arith.constant 1 : i32
2993+
// CHECK: return %[[VAL_0]] : i32
2994+
2995+
// -----
2996+
2997+
// CHECK-LABEL: func @insert_arith_constnt
2998+
2999+
func.func @insert_arith_constnt() -> vector<4x1xi32> {
3000+
%v = arith.constant dense<0> : vector<4x1xi32>
3001+
%c_0 = arith.constant 0 : index
3002+
%c_1 = arith.constant 1 : i32
3003+
%v_1 = vector.insert %c_1, %v[%c_0, %c_0] : i32 into vector<4x1xi32>
3004+
return %v_1 : vector<4x1xi32>
3005+
}
3006+
3007+
// CHECK: %[[VAL_0:.*]] = arith.constant dense<{{\[\[}}1], [0], [0], [0]]> : vector<4x1xi32>
3008+
// CHECK: return %[[VAL_0]] : vector<4x1xi32>
3009+
3010+
// -----
3011+
3012+
// CHECK-LABEL: func @insert_extract_arith_constnt
3013+
3014+
func.func @insert_extract_arith_constnt() -> i32 {
3015+
%v = arith.constant dense<0> : vector<32x1xi32>
3016+
%c_0 = arith.constant 0 : index
3017+
%c_1 = arith.constant 1 : index
3018+
%c_2 = arith.constant 2 : i32
3019+
%v_1 = vector.insert %c_2, %v[%c_1, %c_1] : i32 into vector<32x1xi32>
3020+
%ret = vector.extract %v_1[%c_1, %c_1] : i32 from vector<32x1xi32>
3021+
return %ret : i32
3022+
}
3023+
3024+
// CHECK: %[[VAL_0:.*]] = arith.constant 2 : i32
3025+
// CHECK: return %[[VAL_0]] : i32
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// RUN: mlir-opt %s -test-lower-to-llvm | \
2+
// RUN: mlir-runner -e entry -entry-point-result=void \
3+
// RUN: -shared-libs=%mlir_c_runner_utils | \
4+
// RUN: FileCheck %s
5+
6+
func.func @entry() {
7+
%v = arith.constant dense<0> : vector<2x2xi32>
8+
%c_0 = arith.constant 0 : index
9+
%c_1 = arith.constant 1 : index
10+
%i32_0 = arith.constant 0 : i32
11+
%i32_1 = arith.constant 1 : i32
12+
%i32_2 = arith.constant 2 : i32
13+
%i32_3 = arith.constant 3 : i32
14+
%v_1 = vector.insert %i32_0, %v[%c_0, %c_0] : i32 into vector<2x2xi32>
15+
%v_2 = vector.insert %i32_1, %v_1[%c_0, %c_1] : i32 into vector<2x2xi32>
16+
%v_3 = vector.insert %i32_2, %v_2[%c_1, %c_0] : i32 into vector<2x2xi32>
17+
%v_4 = vector.insert %i32_3, %v_3[%c_1, %c_1] : i32 into vector<2x2xi32>
18+
// CHECK: ( ( 0, 1 ), ( 2, 3 ) )
19+
vector.print %v_4 : vector<2x2xi32>
20+
%v_5 = arith.constant dense<[[0, 1], [2, 3]]> : vector<2x2xi32>
21+
// CHECK: 0
22+
%i32_4 = vector.extract %v_5[%c_0, %c_0] : i32 from vector<2x2xi32>
23+
// CHECK: 1
24+
%i32_5 = vector.extract %v_5[%c_0, %c_1] : i32 from vector<2x2xi32>
25+
// CHECK: 2
26+
%i32_6 = vector.extract %v_5[%c_1, %c_0] : i32 from vector<2x2xi32>
27+
// CHECK: 3
28+
%i32_7 = vector.extract %v_5[%c_1, %c_1] : i32 from vector<2x2xi32>
29+
vector.print %i32_4 : i32
30+
vector.print %i32_5 : i32
31+
vector.print %i32_6 : i32
32+
vector.print %i32_7 : i32
33+
return
34+
}

0 commit comments

Comments
 (0)