Skip to content
48 changes: 48 additions & 0 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1977,6 +1977,48 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
return fromElementsOp.getElements()[flatIndex];
}

// If the dynamic operands of `extractOp` or `insertOp` is result of
// `constantOp`, then fold it.
template <typename T>
static LogicalResult foldConstantOp(T op, SmallVectorImpl<Value> &operands) {
auto staticPosition = op.getStaticPosition().vec();
OperandRange dynamicPosition = op.getDynamicPosition();

// If the dynamic operands is empty, it is returned directly.
if (!dynamicPosition.size())
return failure();
unsigned index = 0;

// `opChange` is a flog. If it is true, it means to update `op` in place.
bool opChange = false;
for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
if (!ShapedType::isDynamic(staticPosition[i]))
continue;
Value position = dynamicPosition[index++];

// If it is a block parameter, proceed to the next iteration.
if (!position.getDefiningOp()) {
operands.push_back(position);
continue;
}

APInt pos;
if (matchPattern(position, m_ConstantInt(&pos))) {
opChange = true;
staticPosition[i] = pos.getSExtValue();
continue;
}
operands.push_back(position);
}

if (opChange) {
op.setStaticPosition(staticPosition);
op.getOperation()->setOperands(operands);
return success();
}
return failure();
}

OpFoldResult ExtractOp::fold(FoldAdaptor) {
// Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
// Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
Expand All @@ -1999,6 +2041,9 @@ OpFoldResult ExtractOp::fold(FoldAdaptor) {
return val;
if (auto val = foldScalarExtractFromFromElements(*this))
return val;
SmallVector<Value> operands = {getVector()};
if (succeeded(foldConstantOp(*this, operands)))
return getResult();
return OpFoldResult();
}

Expand Down Expand Up @@ -3028,6 +3073,9 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
// (type mismatch).
if (getNumIndices() == 0 && getSourceType() == getType())
return getSource();
SmallVector<Value> operands = {getSource(), getDest()};
if (succeeded(foldConstantOp(*this, operands)))
return getResult();
return {};
}

Expand Down
36 changes: 36 additions & 0 deletions mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4115,3 +4115,39 @@ func.func @step_scalable() -> vector<[4]xindex> {
%0 = vector.step : vector<[4]xindex>
return %0 : vector<[4]xindex>
}

// -----

// CHECK-LABEL: @extract_arith_constnt
func.func @extract_arith_constnt() -> i32 {
%v = arith.constant dense<0> : vector<32x1xi32>
%c_0 = arith.constant 0 : index
%elem = vector.extract %v[%c_0, %c_0] : i32 from vector<32x1xi32>
return %elem : i32
}

// CHECK: %[[VAL_0:.*]] = arith.constant dense<0> : vector<32x1xi32>
// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
// CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %{{.*}} = llvm.extractelement %[[VAL_2]]{{\[}}%[[VAL_3]] : i64] : vector<1xi32>

// -----

// CHECK-LABEL: @insert_arith_constnt()

func.func @insert_arith_constnt() -> vector<32x1xi32> {
%v = arith.constant dense<0> : vector<32x1xi32>
%c_0 = arith.constant 0 : index
%c_1 = arith.constant 1 : i32
%v_1 = vector.insert %c_1, %v[%c_0, %c_0] : i32 into vector<32x1xi32>
return %v_1 : vector<32x1xi32>
}

// CHECK: %[[VAL_0:.*]] = arith.constant dense<0> : vector<32x1xi32>
// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
// CHECK: %[[VAL_2:.*]] = arith.constant 1 : i32
// CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[VAL_5:.*]] = llvm.insertelement %[[VAL_2]], %[[VAL_3]]{{\[}}%[[VAL_4]] : i64] : vector<1xi32>
// CHECK: %{{.*}} = llvm.insertvalue %[[VAL_5]], %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
3 changes: 1 addition & 2 deletions mlir/test/Dialect/Vector/vector-warp-distribute.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -778,12 +778,11 @@ func.func @warp_constant(%laneid: index) -> (vector<1xf32>) {

// CHECK-PROP-LABEL: func.func @vector_extract_1d(
// CHECK-PROP-DAG: %[[C5_I32:.*]] = arith.constant 5 : i32
// CHECK-PROP-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-PROP: %[[R:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2xf32>) {
// CHECK-PROP: %[[V:.*]] = "some_def"() : () -> vector<64xf32>
// CHECK-PROP: gpu.yield %[[V]] : vector<64xf32>
// CHECK-PROP: }
// CHECK-PROP: %[[E:.*]] = vector.extract %[[R]][%[[C1]]] : f32 from vector<2xf32>
// CHECK-PROP: %[[E:.*]] = vector.extract %[[R]][1] : f32 from vector<2xf32>
// CHECK-PROP: %[[SHUFFLED:.*]], %{{.*}} = gpu.shuffle idx %[[E]], %[[C5_I32]]
// CHECK-PROP: return %[[SHUFFLED]] : f32
func.func @vector_extract_1d(%laneid: index) -> (f32) {
Expand Down
Loading