Skip to content
44 changes: 43 additions & 1 deletion mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1977,7 +1977,43 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
return fromElementsOp.getElements()[flatIndex];
}

OpFoldResult ExtractOp::fold(FoldAdaptor) {
/// Folds constant dynamic position operands of an extract or insert op into
/// the op's static position list, updating `op` in place.
///
/// \param op        The op being folded; it must provide `getStaticPosition`,
///                  `getDynamicPosition`, `setStaticPosition` and `getResult`.
/// \param adaptor   Fold adaptor; `getDynamicPosition()` yields one attribute
///                  per dynamic position operand. Entries may be null, which
///                  is why `dyn_cast_if_present` is used below.
/// \param operands  In: the op's operands that precede the dynamic positions,
///                  in operand order. Out: those operands followed by every
///                  dynamic position that was not folded. This list becomes
///                  the op's new operand list when a fold happens.
/// \returns `op.getResult()` if at least one position was folded (the op was
///          modified in place), otherwise a null `Value`.
template <typename OpType, typename AdaptorType>
static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
                                         SmallVectorImpl<Value> &operands) {
  OperandRange dynamicPosition = op.getDynamicPosition();
  // Nothing to fold when the op has no dynamic position operands; bail out
  // before copying the static position list.
  if (dynamicPosition.empty())
    return {};

  auto staticPosition = op.getStaticPosition().vec();
  ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();

  // `index` walks `dynamicPosition` / `dynamicPositionAttr` in lock-step; it
  // advances once for every static slot that is marked dynamic.
  unsigned index = 0;

  // `opChange` records whether at least one position was folded, i.e. whether
  // `op` has to be updated in place.
  bool opChange = false;
  for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
    if (!ShapedType::isDynamic(staticPosition[i]))
      continue;
    Attribute positionAttr = dynamicPositionAttr[index];
    Value position = dynamicPosition[index++];
    if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
      int64_t constantIndex = attr.getInt();
      // A constant that equals the dynamic sentinel cannot be stored as a
      // static position: the slot would still read as dynamic while its
      // operand had been dropped. Keep such a position as an operand.
      // NOTE(review): the constant is not checked against the vector shape
      // here, so an out-of-range constant becomes an out-of-range static
      // index. Confirm whether a bounds check is required before folding.
      if (!ShapedType::isDynamic(constantIndex)) {
        staticPosition[i] = constantIndex;
        opChange = true;
        continue;
      }
    }
    operands.push_back(position);
  }

  if (opChange) {
    op.setStaticPosition(staticPosition);
    op.getOperation()->setOperands(operands);
    // Returning the op's own result signals an in-place fold to the caller.
    return op.getResult();
  }
  return {};
}

OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
// Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
// Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
// mismatch).
Expand All @@ -1999,6 +2035,9 @@ OpFoldResult ExtractOp::fold(FoldAdaptor) {
return val;
if (auto val = foldScalarExtractFromFromElements(*this))
return val;
SmallVector<Value> operands = {getVector()};
if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
return val;
return OpFoldResult();
}

Expand Down Expand Up @@ -3028,6 +3067,9 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
// (type mismatch).
if (getNumIndices() == 0 && getSourceType() == getType())
return getSource();
SmallVector<Value> operands = {getSource(), getDest()};
if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
return val;
return {};
}

Expand Down
44 changes: 44 additions & 0 deletions mlir/test/Dialect/Vector/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2979,3 +2979,47 @@ func.func @contiguous_scatter_step(%base: memref<?xf32>,
memref<?xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32>
return
}

// -----

// CHECK-LABEL: func @extract_arith_constnt

// The position operands are `arith.constant` results, so they are expected to
// fold into the static position; the extract from the splat constant is then
// expected to canonicalize to a scalar constant.
func.func @extract_arith_constnt() -> i32 {
  %v = arith.constant dense<1> : vector<4x1xi32>
  %c_0 = arith.constant 0 : index
  %elem = vector.extract %v[%c_0, %c_0] : i32 from vector<4x1xi32>
  return %elem : i32
}

// CHECK: %[[VAL_0:.*]] = arith.constant 1 : i32
// CHECK: return %[[VAL_0]] : i32

// -----

// CHECK-LABEL: func @insert_arith_constnt

// Inserts a constant scalar into a constant vector at a position given by
// constant index operands; the whole function should reduce to one constant.
func.func @insert_arith_constnt() -> vector<4x1xi32> {
  %dest = arith.constant dense<0> : vector<4x1xi32>
  %idx0 = arith.constant 0 : index
  %one = arith.constant 1 : i32
  %res = vector.insert %one, %dest[%idx0, %idx0] : i32 into vector<4x1xi32>
  return %res : vector<4x1xi32>
}

// CHECK: %[[CST:.*]] = arith.constant dense<{{\[\[}}1], [0], [0], [0]]> : vector<4x1xi32>
// CHECK: return %[[CST]] : vector<4x1xi32>

// -----

// CHECK-LABEL: func @insert_extract_arith_constnt

// Inserts and then extracts at the same position, with every index supplied
// as an `arith.constant` operand. The inner dimension has size 1, so its only
// valid index is 0.
func.func @insert_extract_arith_constnt() -> i32 {
  %v = arith.constant dense<0> : vector<32x1xi32>
  %c_0 = arith.constant 0 : index
  %c_1 = arith.constant 1 : index
  %c_2 = arith.constant 2 : i32
  %v_1 = vector.insert %c_2, %v[%c_1, %c_0] : i32 into vector<32x1xi32>
  %ret = vector.extract %v_1[%c_1, %c_0] : i32 from vector<32x1xi32>
  return %ret : i32
}

// CHECK: %[[VAL_0:.*]] = arith.constant 2 : i32
// CHECK: return %[[VAL_0]] : i32
3 changes: 1 addition & 2 deletions mlir/test/Dialect/Vector/vector-warp-distribute.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -778,12 +778,11 @@ func.func @warp_constant(%laneid: index) -> (vector<1xf32>) {

// CHECK-PROP-LABEL: func.func @vector_extract_1d(
// CHECK-PROP-DAG: %[[C5_I32:.*]] = arith.constant 5 : i32
// CHECK-PROP-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-PROP: %[[R:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2xf32>) {
// CHECK-PROP: %[[V:.*]] = "some_def"() : () -> vector<64xf32>
// CHECK-PROP: gpu.yield %[[V]] : vector<64xf32>
// CHECK-PROP: }
// CHECK-PROP: %[[E:.*]] = vector.extract %[[R]][%[[C1]]] : f32 from vector<2xf32>
// CHECK-PROP: %[[E:.*]] = vector.extract %[[R]][1] : f32 from vector<2xf32>
// CHECK-PROP: %[[SHUFFLED:.*]], %{{.*}} = gpu.shuffle idx %[[E]], %[[C5_I32]]
// CHECK-PROP: return %[[SHUFFLED]] : f32
func.func @vector_extract_1d(%laneid: index) -> (f32) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// RUN: mlir-opt %s -test-lower-to-llvm | \
// RUN: mlir-runner -e entry -entry-point-result=void \
// RUN: -shared-libs=%mlir_c_runner_utils | \
// RUN: FileCheck %s

// Fills a 2x2 vector element by element and reads a 2x2 constant back element
// by element, using `arith.constant` index operands for every position.
func.func @entry() {
  %zeros = arith.constant dense<0> : vector<2x2xi32>
  %idx0 = arith.constant 0 : index
  %idx1 = arith.constant 1 : index
  %val0 = arith.constant 0 : i32
  %val1 = arith.constant 1 : i32
  %val2 = arith.constant 2 : i32
  %val3 = arith.constant 3 : i32
  %ins00 = vector.insert %val0, %zeros[%idx0, %idx0] : i32 into vector<2x2xi32>
  %ins01 = vector.insert %val1, %ins00[%idx0, %idx1] : i32 into vector<2x2xi32>
  %ins10 = vector.insert %val2, %ins01[%idx1, %idx0] : i32 into vector<2x2xi32>
  %ins11 = vector.insert %val3, %ins10[%idx1, %idx1] : i32 into vector<2x2xi32>
  // CHECK: ( ( 0, 1 ), ( 2, 3 ) )
  vector.print %ins11 : vector<2x2xi32>
  %matrix = arith.constant dense<[[0, 1], [2, 3]]> : vector<2x2xi32>
  %ext00 = vector.extract %matrix[%idx0, %idx0] : i32 from vector<2x2xi32>
  %ext01 = vector.extract %matrix[%idx0, %idx1] : i32 from vector<2x2xi32>
  %ext10 = vector.extract %matrix[%idx1, %idx0] : i32 from vector<2x2xi32>
  %ext11 = vector.extract %matrix[%idx1, %idx1] : i32 from vector<2x2xi32>
  // CHECK: 0
  vector.print %ext00 : i32
  // CHECK: 1
  vector.print %ext01 : i32
  // CHECK: 2
  vector.print %ext10 : i32
  // CHECK: 3
  vector.print %ext11 : i32
  return
}
Loading