diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index c9d637ce81f93..94efec61a466c 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -35,13 +35,6 @@ using namespace mlir; using namespace mlir::vector; -// Helper to reduce vector type by *all* but one rank at back. -static VectorType reducedVectorTypeBack(VectorType tp) { - assert((tp.getRank() > 1) && "unlowerable vector type"); - return VectorType::get(tp.getShape().take_back(), tp.getElementType(), - tp.getScalableDims().take_back()); -} - // Helper that picks the proper sequence for inserting. static Value insertOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, @@ -1223,7 +1216,6 @@ class VectorInsertOpConversion matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = insertOp->getLoc(); - auto sourceType = insertOp.getSourceType(); auto destVectorType = insertOp.getDestVectorType(); auto llvmResultType = typeConverter->convertType(destVectorType); // Bail if result type cannot be lowered. @@ -1233,53 +1225,81 @@ class VectorInsertOpConversion SmallVector positionVec = getMixedValues( adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter); - // Overwrite entire vector with value. Should be handled by folder, but - // just to be safe. - ArrayRef position(positionVec); - if (position.empty()) { - rewriter.replaceOp(insertOp, adaptor.getSource()); - return success(); - } - - // One-shot insertion of a vector into an array (only requires insertvalue). - if (isa(sourceType)) { - if (insertOp.hasDynamicPosition()) - return failure(); - - Value inserted = rewriter.create( - loc, adaptor.getDest(), adaptor.getSource(), getAsIntegers(position)); - rewriter.replaceOp(insertOp, inserted); - return success(); + // The logic in this pattern mirrors VectorExtractOpConversion. Refer to + // its explanatory comment about how N-D vectors are converted as nested + // aggregates (llvm.array's) of 1D vectors. + // + // The innermost dimension of the destination vector, when converted to a + // nested aggregate form, will always be a 1D vector. + // + // * If the insertion is happening into the innermost dimension of the + // destination vector: + // - If the destination is a nested aggregate, extract a 1D vector out of + // the aggregate. This can be done using llvm.extractvalue. The + // destination is now guaranteed to be a 1D vector, to which we are + // inserting. + // - Do the insertion into the 1D destination vector, and make the result + // the new source nested aggregate. This can be done using + // llvm.insertelement. + // * Insert the source nested aggregate into the destination nested + // aggregate. + + // Determine if we need to extract/insert a 1D vector out of the aggregate. + bool isNestedAggregate = isa(llvmResultType); + // Determine if we need to insert a scalar into the 1D vector. + bool insertIntoInnermostDim = + static_cast(positionVec.size()) == destVectorType.getRank(); + + ArrayRef positionOf1DVectorWithinAggregate( + positionVec.begin(), + insertIntoInnermostDim ? positionVec.size() - 1 : positionVec.size()); + OpFoldResult positionOfScalarWithin1DVector; + if (destVectorType.getRank() == 0) { + // Since the LLVM type converter converts 0D vectors to 1D vectors, we + // need to create a 0 here as the position into the 1D vector. + Type idxType = typeConverter->convertType(rewriter.getIndexType()); + positionOfScalarWithin1DVector = rewriter.getZeroAttr(idxType); + } else if (insertIntoInnermostDim) { + positionOfScalarWithin1DVector = positionVec.back(); } - // Potential extraction of 1-D vector from array. - Value extracted = adaptor.getDest(); - auto oneDVectorType = destVectorType; - if (position.size() > 1) { - if (insertOp.hasDynamicPosition()) - return failure(); - - oneDVectorType = reducedVectorTypeBack(destVectorType); - extracted = rewriter.create( - loc, extracted, getAsIntegers(position.drop_back())); + // We are going to mutate this 1D vector until it is either the final + // result (in the non-aggregate case) or the value that needs to be + // inserted into the aggregate result. + Value sourceAggregate = adaptor.getSource(); + if (insertIntoInnermostDim) { + // Scalar-into-1D-vector case, so we know we will have to create a + // InsertElementOp. The question is into what destination. + if (isNestedAggregate) { + // Aggregate case: the destination for the InsertElementOp needs to be + // extracted from the aggregate. + if (!llvm::all_of(positionOf1DVectorWithinAggregate, + llvm::IsaPred)) { + // llvm.extractvalue does not support dynamic dimensions. + return failure(); + } + sourceAggregate = rewriter.create( + loc, adaptor.getDest(), + getAsIntegers(positionOf1DVectorWithinAggregate)); + } else { + // No-aggregate case. The destination for the InsertElementOp is just + // the insertOp's destination. + sourceAggregate = adaptor.getDest(); + } + // Insert the scalar into the 1D vector. + sourceAggregate = rewriter.create( + loc, sourceAggregate.getType(), sourceAggregate, adaptor.getSource(), + getAsLLVMValue(rewriter, loc, positionOfScalarWithin1DVector)); } - // Insertion of an element into a 1-D LLVM vector. - Value inserted = rewriter.create( - loc, typeConverter->convertType(oneDVectorType), extracted, - adaptor.getSource(), getAsLLVMValue(rewriter, loc, position.back())); - - // Potential insertion of resulting 1-D vector into array. - if (position.size() > 1) { - if (insertOp.hasDynamicPosition()) - return failure(); - - inserted = rewriter.create( - loc, adaptor.getDest(), inserted, - getAsIntegers(position.drop_back())); + Value result = sourceAggregate; + if (isNestedAggregate) { + result = rewriter.create( + loc, adaptor.getDest(), sourceAggregate, + getAsIntegers(positionOf1DVectorWithinAggregate)); } - rewriter.replaceOp(insertOp, inserted); + rewriter.replaceOp(insertOp, result); return success(); } }; diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir index fa7c030538401..c3f06dd4d5dd1 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir @@ -628,6 +628,16 @@ func.func @insertelement_into_vec_1d_f32_scalable_idx_as_index_scalable(%arg0: f // vector.insert //===----------------------------------------------------------------------===// +func.func @insert_scalar_into_vec_0d(%src: f32, %dst: vector) -> vector { + %0 = vector.insert %src, %dst[] : f32 into vector + return %0 : vector +} + +// CHECK-LABEL: @insert_scalar_into_vec_0d +// CHECK: llvm.insertelement {{.*}} : vector<1xf32> + +// ----- + func.func @insert_scalar_into_vec_1d_f32(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> { %0 = vector.insert %arg0, %arg1[3] : f32 into vector<4xf32> return %0 : vector<4xf32> @@ -780,10 +790,10 @@ func.func @insert_scalar_into_vec_2d_f32_dynamic_idx(%arg0: vector<1x16xf32>, %a return %0 : vector<1x16xf32> } -// Multi-dim vectors are not supported but this test shouldn't crash. - // CHECK-LABEL: @insert_scalar_into_vec_2d_f32_dynamic_idx( -// CHECK: vector.insert +// CHECK: llvm.extractvalue {{.*}} : !llvm.array<1 x vector<16xf32>> +// CHECK: llvm.insertelement {{.*}} : vector<16xf32> +// CHECK: llvm.insertvalue {{.*}} : !llvm.array<1 x vector<16xf32>> // ----- @@ -793,10 +803,26 @@ func.func @insert_scalar_into_vec_2d_f32_dynamic_idx_scalable(%arg0: vector<1x[1 return %0 : vector<1x[16]xf32> } -// Multi-dim vectors are not supported but this test shouldn't crash. - // CHECK-LABEL: @insert_scalar_into_vec_2d_f32_dynamic_idx_scalable( -// CHECK: vector.insert +// CHECK: llvm.extractvalue {{.*}} : !llvm.array<1 x vector<[16]xf32>> +// CHECK: llvm.insertelement {{.*}} : vector<[16]xf32> +// CHECK: llvm.insertvalue {{.*}} : !llvm.array<1 x vector<[16]xf32>> + + +// ----- + +func.func @insert_scalar_into_vec_2d_f32_dynamic_idx_fail(%arg0: vector<2x16xf32>, %arg1: f32, %idx: index) + -> vector<2x16xf32> { + %0 = vector.insert %arg1, %arg0[%idx, 0]: f32 into vector<2x16xf32> + return %0 : vector<2x16xf32> +} + +// Currently fails to convert because of the dynamic index in non-innermost +// dimension that converts to a llvm.array, as llvm.extractvalue does not +// support dynamic dimensions + +// CHECK-LABEL: @insert_scalar_into_vec_2d_f32_dynamic_idx_fail +// CHECK: vector.insert // -----