Skip to content

Commit fa488d5

Browse files
support extract extract.
1 parent 82fecab commit fa488d5

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,6 +1096,29 @@ class VectorExtractOpConversion
10961096
SmallVector<OpFoldResult> positionVec = getMixedValues(
10971097
adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
10981098

1099+
for (unsigned idx = 0; idx < positionVec.size(); ++idx) {
1100+
if (auto position = llvm::dyn_cast<Value>(positionVec[idx])) {
1101+
auto defOp = position.getDefiningOp();
1102+
while (true) {
1103+
if (!defOp) {
1104+
break;
1105+
}
1106+
if (llvm::isa<arith::ConstantOp, LLVM::ConstantOp>(defOp)) {
1107+
Attribute value =
1108+
defOp->getAttr(arith::ConstantOp::getAttributeNames()[0]);
1109+
positionVec[idx] = OpFoldResult{
1110+
rewriter.getI64IntegerAttr(cast<IntegerAttr>(value).getInt())};
1111+
break;
1112+
} else if (auto unrealizedCastOp =
1113+
llvm::dyn_cast<UnrealizedConversionCastOp>(defOp)) {
1114+
defOp = unrealizedCastOp.getOperand(0).getDefiningOp();
1115+
} else {
1116+
break;
1117+
}
1118+
}
1119+
}
1120+
}
1121+
10991122
// The Vector -> LLVM lowering models N-D vectors as nested aggregates of
11001123
// 1-d vectors. This nesting is modeled using arrays. We do this conversion
11011124
// from a N-d vector extract to a nested aggregate vector extract in two

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4094,3 +4094,43 @@ func.func @step_scalable() -> vector<[4]xindex> {
40944094
%0 = vector.step : vector<[4]xindex>
40954095
return %0 : vector<[4]xindex>
40964096
}
4097+
4098+
// -----
4099+
4100+
// CHECK-LABEL: @extract_arith_constnt
4101+
func.func @extract_arith_constnt() -> i32 {
4102+
%v = arith.constant dense<0> : vector<32x1xi32>
4103+
%c_0 = arith.constant 0 : index
4104+
%elem = vector.extract %v[%c_0, %c_0] : i32 from vector<32x1xi32>
4105+
return %elem : i32
4106+
}
4107+
4108+
// CHECK: %[[VAL_0:.*]] = arith.constant dense<0> : vector<32x1xi32>
4109+
// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
4110+
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
4111+
// CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
4112+
// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i64) : i64
4113+
// CHECK: %[[VAL_5:.*]] = llvm.extractelement %[[VAL_3]]{{\[}}%[[VAL_4]] : i64] : vector<1xi32>
4114+
// CHECK: return %[[VAL_5]] : i32
4115+
4116+
// -----
4117+
4118+
// CHECK-LABEL: @extract_llvm_constnt()
4119+
4120+
module {
4121+
func.func @extract_llvm_constnt() -> i32 {
4122+
%0 = llvm.mlir.constant(dense<0> : vector<32x1xi32>) : !llvm.array<32 x vector<1xi32>>
4123+
%1 = builtin.unrealized_conversion_cast %0 : !llvm.array<32 x vector<1xi32>> to vector<32x1xi32>
4124+
%2 = llvm.mlir.constant(0 : index) : i64
4125+
%3 = builtin.unrealized_conversion_cast %2 : i64 to index
4126+
%4 = vector.extract %1[%3, %3] : i32 from vector<32x1xi32>
4127+
return %4 : i32
4128+
}
4129+
}
4130+
4131+
// CHECK: %[[VAL_0:.*]] = llvm.mlir.constant(0 : index) : i64
4132+
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(dense<0> : vector<32x1xi32>) : !llvm.array<32 x vector<1xi32>>
4133+
// CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
4134+
// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(0 : i64) : i64
4135+
// CHECK: %[[VAL_4:.*]] = llvm.extractelement %[[VAL_2]]{{\[}}%[[VAL_3]] : i64] : vector<1xi32>
4136+
// CHECK: return %[[VAL_4]] : i32

0 commit comments

Comments
 (0)