Skip to content
45 changes: 45 additions & 0 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1989,6 +1989,45 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
return fromElementsOp.getElements()[flatIndex];
}

/// Folds constant dynamic indices of an extract or insert op into its static
/// position, updating `op` in place.
///
/// `adaptor` is the fold adaptor of `op`: an entry of its dynamic position
/// that is an `IntegerAttr` marks the matching index operand as a known
/// constant.
///
/// On entry `operands` must hold the non-index operands of `op`, in operand
/// order, with the indexed vector last (both folders in this file pass it that
/// way). Index operands that stay dynamic are appended to it; when at least
/// one index was folded the combined list becomes the operand list of `op`.
///
/// A constant is folded only if it is a valid static position for its
/// dimension, i.e. it lies in `[0, dim)` or is -1 (the poison index). Folding
/// anything else would turn a runtime out-of-bounds access into an op with an
/// invalid static position.
///
/// Returns the result of `op` when it was updated in place, and a null `Value`
/// when nothing changed.
template <typename OpType, typename AdaptorType>
static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
                                         SmallVectorImpl<Value> &operands) {
  OperandRange dynamicPosition = op.getDynamicPosition();

  // Without dynamic indices there is nothing to fold.
  if (dynamicPosition.empty())
    return {};

  // The dimension sizes that bound each index come from the indexed vector,
  // which the callers place last among the leading operands. Without it no
  // constant can be validated, so leave the op untouched.
  if (operands.empty())
    return {};
  auto vectorType =
      mlir::dyn_cast_if_present<ShapedType>(operands.back().getType());
  if (!vectorType)
    return {};
  ArrayRef<int64_t> vectorShape = vectorType.getShape();

  std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
  ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();

  // `index` walks `dynamicPosition`; it advances only on the entries of
  // `staticPosition` that are marked dynamic.
  unsigned index = 0;

  // Set once any index was folded, i.e. `op` has to be updated in place.
  bool opChange = false;
  for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
    if (!ShapedType::isDynamic(staticPosition[i]))
      continue;
    Attribute positionAttr = dynamicPositionAttr[index];
    Value position = dynamicPosition[index++];
    if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
      int64_t value = attr.getInt();
      // Only fold in-bounds constants and the poison index (-1); an
      // out-of-bounds constant stays a dynamic operand.
      if (i < vectorShape.size() && value >= -1 && value < vectorShape[i]) {
        staticPosition[i] = value;
        opChange = true;
        continue;
      }
    }
    operands.push_back(position);
  }

  if (!opChange)
    return {};

  // Returning the op's own result signals an in-place update to the folder.
  op.setStaticPosition(staticPosition);
  op.getOperation()->setOperands(operands);
  return op.getResult();
}

/// Fold an insert or extract operation into a poison value when a poison index
/// is found at any dimension of the static position.
static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context,
Expand Down Expand Up @@ -2035,6 +2074,9 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
return val;
if (auto val = foldScalarExtractFromFromElements(*this))
return val;
SmallVector<Value> operands = {getVector()};
if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
return val;
return OpFoldResult();
}

Expand Down Expand Up @@ -3094,6 +3136,9 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
// (type mismatch).
if (getNumIndices() == 0 && getSourceType() == getType())
return getSource();
SmallVector<Value> operands = {getSource(), getDest()};
if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
return val;
if (auto res = foldPoisonIndexInsertExtractOp(
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
return res;
Expand Down
42 changes: 42 additions & 0 deletions mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,25 @@ func.func @extract_scalar_from_vec_0d_index(%arg0: vector<index>) -> index {

// -----

// Lowering input: extracts element [0, 0] of a 32x1 vector where both indices
// are SSA values defined by a single `arith.constant`, not literal positions.
// Note: despite `f32` in the function name, the element type used here is i32.
func.func @extract_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const(%arg : vector<32x1xi32>) -> i32 {
// Constant index reused for both positions of the extract below.
%0 = arith.constant 0 : index
%1 = vector.extract %arg[%0, %0] : i32 from vector<32x1xi32>
return %1 : i32
}

// The indices of the vector.extract are defined by constants, so they are
// folded into static positions before the conversion and the op lowers as shown below.

// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const
// CHECK-SAME: %[[ARG:.*]]: vector<32x1xi32>) -> i32 {
// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
// CHECK: %[[VEC_0:.*]] = llvm.extractvalue %[[CAST]][0] : !llvm.array<32 x vector<1xi32>>
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[RES:.*]] = llvm.extractelement %[[VEC_0]]{{\[}}%[[C0]] : i64] : vector<1xi32>
// CHECK: return %[[RES]] : i32

// -----

//===----------------------------------------------------------------------===//
// vector.insertelement
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -781,6 +800,29 @@ func.func @insert_scalar_into_vec_2d_f32_dynamic_idx_scalable(%arg0: vector<1x[1

// -----

// Lowering input: inserts the i32 constant 1 at position [0, 0] of a 4x1
// vector where both indices are SSA values defined by a single
// `arith.constant`, not literal positions.
// Note: despite `f32` in the function name, the element type used here is i32.
func.func @insert_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const(%arg : vector<4x1xi32>) -> vector<4x1xi32> {
// Constant index reused for both positions of the insert below.
%0 = arith.constant 0 : index
// Scalar value to store.
%1 = arith.constant 1 : i32
%res = vector.insert %1, %arg[%0, %0] : i32 into vector<4x1xi32>
return %res : vector<4x1xi32>
}

// The indices of the vector.insert are defined by constants, so they are
// folded into static positions before the conversion and the op lowers as shown below.

// CHECK-LABEL: @insert_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const
// CHECK-SAME: %[[ARG:.*]]: vector<4x1xi32>) -> vector<4x1xi32> {
// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x1xi32> to !llvm.array<4 x vector<1xi32>>
// CHECK: %[[C1:.*]] = arith.constant 1 : i32
// CHECK: %[[VEC_0:.*]] = llvm.extractvalue %[[CAST]][0] : !llvm.array<4 x vector<1xi32>>
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[VEC_1:.*]] = llvm.insertelement %[[C1]], %[[VEC_0]]{{\[}}%[[C0]] : i64] : vector<1xi32>
// CHECK: %[[VEC_2:.*]] = llvm.insertvalue %[[VEC_1]], %[[CAST]][0] : !llvm.array<4 x vector<1xi32>>
// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[VEC_2]] : !llvm.array<4 x vector<1xi32>> to vector<4x1xi32>
// CHECK: return %[[RES]] : vector<4x1xi32>

// -----

//===----------------------------------------------------------------------===//
// vector.type_cast
//
Expand Down
26 changes: 26 additions & 0 deletions mlir/test/Dialect/Vector/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3171,3 +3171,29 @@ func.func @contiguous_scatter_step(%base: memref<?xf32>,
memref<?xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32>
return
}

// -----

// CHECK-LABEL: @fold_extract_constant_indices
// CHECK-SAME: %[[ARG:.*]]: vector<32x1xi32>) -> i32 {
// CHECK: %[[RES:.*]] = vector.extract %[[ARG]][0, 0] : i32 from vector<32x1xi32>
// CHECK: return %[[RES]] : i32
// Canonicalization input: a vector.extract whose two indices are the same
// constant-defined SSA value. The expected output above uses the literal
// position [0, 0] instead.
func.func @fold_extract_constant_indices(%arg : vector<32x1xi32>) -> i32 {
// Constant index reused for both positions of the extract below.
%0 = arith.constant 0 : index
%1 = vector.extract %arg[%0, %0] : i32 from vector<32x1xi32>
return %1 : i32
}

// -----

// CHECK-LABEL: @fold_insert_constant_indices
// CHECK-SAME: %[[ARG:.*]]: vector<4x1xi32>) -> vector<4x1xi32> {
// CHECK: %[[VAL:.*]] = arith.constant 1 : i32
// CHECK: %[[RES:.*]] = vector.insert %[[VAL]], %[[ARG]] [0, 0] : i32 into vector<4x1xi32>
// CHECK: return %[[RES]] : vector<4x1xi32>
// Canonicalization input: a vector.insert whose two indices are the same
// constant-defined SSA value. The expected output above uses the literal
// position [0, 0] instead and keeps the constant being stored.
func.func @fold_insert_constant_indices(%arg : vector<4x1xi32>) -> vector<4x1xi32> {
// Constant index reused for both positions of the insert below.
%0 = arith.constant 0 : index
// Scalar value to store.
%1 = arith.constant 1 : i32
%res = vector.insert %1, %arg[%0, %0] : i32 into vector<4x1xi32>
return %res : vector<4x1xi32>
}
3 changes: 1 addition & 2 deletions mlir/test/Dialect/Vector/vector-warp-distribute.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -778,12 +778,11 @@ func.func @warp_constant(%laneid: index) -> (vector<1xf32>) {

// CHECK-PROP-LABEL: func.func @vector_extract_1d(
// CHECK-PROP-DAG: %[[C5_I32:.*]] = arith.constant 5 : i32
// CHECK-PROP-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-PROP: %[[R:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2xf32>) {
// CHECK-PROP: %[[V:.*]] = "some_def"() : () -> vector<64xf32>
// CHECK-PROP: gpu.yield %[[V]] : vector<64xf32>
// CHECK-PROP: }
// CHECK-PROP: %[[E:.*]] = vector.extract %[[R]][%[[C1]]] : f32 from vector<2xf32>
// CHECK-PROP: %[[E:.*]] = vector.extract %[[R]][1] : f32 from vector<2xf32>
// CHECK-PROP: %[[SHUFFLED:.*]], %{{.*}} = gpu.shuffle idx %[[E]], %[[C5_I32]]
// CHECK-PROP: return %[[SHUFFLED]] : f32
func.func @vector_extract_1d(%laneid: index) -> (f32) {
Expand Down