diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td index 299f198e4ab9c..07a4117a37b2c 100644 --- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td +++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td @@ -254,6 +254,17 @@ def ApplyLowerGatherPatternsOp : Op]> { + let description = [{ + Indicates that vector from_elements operations should be unrolled + along the outermost dimension. + }]; + + let assemblyFormat = "attr-dict"; +} + def ApplyLowerScanPatternsOp : Op]> { diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h index e03f0dabece52..47f96112a9433 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h @@ -303,6 +303,14 @@ void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns); void populateVectorToFromElementsToShuffleTreePatterns( RewritePatternSet &patterns, PatternBenefit benefit = 1); +/// Populate the pattern set with the following patterns: +/// +/// [UnrollFromElements] +/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the +/// outermost dimension. +void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); + /// Populate the pattern set with the following patterns: /// /// [ContractionOpToMatmulOpLowering] diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h index 8bd54cf31b893..ace26990601c8 100644 --- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h +++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h @@ -12,6 +12,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinAttributes.h" @@ -238,6 +239,22 @@ Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, /// static sizes in `shape`. LogicalResult isValidMaskedInputVector(ArrayRef shape, ArrayRef inputVectorSizes); + +/// Generic utility for unrolling n-D vector operations to (n-1)-D operations. +/// This handles the common pattern of: +/// 1. Check if already 1-D. If so, return failure. +/// 2. Check for scalable dimensions. If so, return failure. +/// 3. Create poison initialized result. +/// 4. Loop through the outermost dimension, execute the UnrollVectorOpFn to +/// create sub vectors. +/// 5. Insert the sub vectors back into the final vector. +/// 6. Replace the original op with the new result. +using UnrollVectorOpFn = + function_ref; + +LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter, + UnrollVectorOpFn unrollFn); + } // namespace vector /// Constructs a permutation map of invariant memref indices to vector diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index f9e2a01dbf969..afc3d1b12ac0d 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1891,15 +1891,21 @@ struct VectorFromElementsLowering ConversionPatternRewriter &rewriter) const override { Location loc = fromElementsOp.getLoc(); VectorType vectorType = fromElementsOp.getType(); - // TODO: Multi-dimensional vectors lower to !llvm.array<... x vector<>>. - // Such ops should be handled in the same way as vector.insert. + // Only support 1-D vectors. Multi-dimensional vectors should have been + // transformed to 1-D vectors by the vector-to-vector transformations before + // this. if (vectorType.getRank() > 1) return rewriter.notifyMatchFailure(fromElementsOp, "rank > 1 vectors are not supported"); Type llvmType = typeConverter->convertType(vectorType); + Type llvmIndexType = typeConverter->convertType(rewriter.getIndexType()); Value result = LLVM::PoisonOp::create(rewriter, loc, llvmType); - for (auto [idx, val] : llvm::enumerate(adaptor.getElements())) - result = vector::InsertOp::create(rewriter, loc, val, result, idx); + for (auto [idx, val] : llvm::enumerate(adaptor.getElements())) { + auto constIdx = + LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, idx); + result = LLVM::InsertElementOp::create(rewriter, loc, llvmType, result, + val, constIdx); + } rewriter.replaceOp(fromElementsOp, result); return success(); } diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index cf108690c3741..9852df6970fdc 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -94,6 +94,7 @@ void ConvertVectorToLLVMPass::runOnOperation() { populateVectorStepLoweringPatterns(patterns); populateVectorRankReducingFMAPattern(patterns); populateVectorGatherLoweringPatterns(patterns); + populateVectorFromElementsLoweringPatterns(patterns); if (armI8MM) { if (armNeon) arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns); diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp index 2d5cc070558c3..fe066dc04ad55 100644 --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -139,6 +139,11 @@ void transform::ApplyLowerGatherPatternsOp::populatePatterns( vector::populateVectorGatherLoweringPatterns(patterns); } +void transform::ApplyUnrollFromElementsPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + vector::populateVectorFromElementsLoweringPatterns(patterns); +} + void transform::ApplyLowerScanPatternsOp::populatePatterns( RewritePatternSet &patterns) { vector::populateVectorScanLoweringPatterns(patterns); diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt index 9e287fc109990..acbf2b746037b 100644 --- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRVectorTransforms LowerVectorBitCast.cpp LowerVectorBroadcast.cpp LowerVectorContract.cpp + LowerVectorFromElements.cpp LowerVectorGather.cpp LowerVectorInterleave.cpp LowerVectorMask.cpp diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp new file mode 100644 index 0000000000000..c22fd54cef46b --- /dev/null +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp @@ -0,0 +1,65 @@ +//===- LowerVectorFromElements.cpp - Lower 'vector.from_elements' op -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements target-independent rewrites and utilities to lower the +// 'vector.from_elements' operation. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" + +#define DEBUG_TYPE "lower-vector-from-elements" + +using namespace mlir; + +namespace { + +/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the +/// outermost dimension. For example: +/// ``` +/// %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : vector<2x3xf32> +/// +/// ==> +/// +/// %0 = ub.poison : vector<2x3xf32> +/// %v0 = vector.from_elements %e0, %e1, %e2 : vector<3xf32> +/// %1 = vector.insert %v0, %0 [0] : vector<3xf32> into vector<2x3xf32> +/// %v1 = vector.from_elements %e3, %e4, %e5 : vector<3xf32> +/// %v = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32> +/// ``` +/// +/// When applied exhaustively, this will produce a sequence of 1-d from_elements +/// ops. +struct UnrollFromElements : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::FromElementsOp op, + PatternRewriter &rewriter) const override { + ValueRange allElements = op.getElements(); + + auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc, + VectorType subTy, int64_t index) { + size_t subTyNumElements = subTy.getNumElements(); + assert((index + 1) * subTyNumElements <= allElements.size() && + "out of bounds"); + ValueRange subElements = + allElements.slice(index * subTyNumElements, subTyNumElements); + return vector::FromElementsOp::create(rewriter, loc, subTy, subElements); + }; + + return unrollVectorOp(op, rewriter, unrollFromElementsFn); + } +}; + +} // namespace + +void mlir::vector::populateVectorFromElementsLoweringPatterns( + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(patterns.getContext(), benefit); +} diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp index e062f55f87679..90f21c53246b0 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp @@ -54,27 +54,13 @@ struct UnrollGather : OpRewritePattern { LogicalResult matchAndRewrite(vector::GatherOp op, PatternRewriter &rewriter) const override { - VectorType resultTy = op.getType(); - if (resultTy.getRank() < 2) - return rewriter.notifyMatchFailure(op, "already 1-D"); - - // Unrolling doesn't take vscale into account. Pattern is disabled for - // vectors with leading scalable dim(s). - if (resultTy.getScalableDims().front()) - return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim"); - - Location loc = op.getLoc(); Value indexVec = op.getIndexVec(); Value maskVec = op.getMask(); Value passThruVec = op.getPassThru(); - Value result = arith::ConstantOp::create(rewriter, loc, resultTy, - rewriter.getZeroAttr(resultTy)); - - VectorType subTy = VectorType::Builder(resultTy).dropDim(0); - - for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) { - int64_t thisIdx[1] = {i}; + auto unrollGatherFn = [&](PatternRewriter &rewriter, Location loc, + VectorType subTy, int64_t index) { + int64_t thisIdx[1] = {index}; Value indexSubVec = vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx); @@ -82,15 +68,12 @@ struct UnrollGather : OpRewritePattern { vector::ExtractOp::create(rewriter, loc, maskVec, thisIdx); Value passThruSubVec = vector::ExtractOp::create(rewriter, loc, passThruVec, thisIdx); - Value subGather = vector::GatherOp::create( - rewriter, loc, subTy, op.getBase(), op.getIndices(), indexSubVec, - maskSubVec, passThruSubVec); - result = - vector::InsertOp::create(rewriter, loc, subGather, result, thisIdx); - } + return vector::GatherOp::create(rewriter, loc, subTy, op.getBase(), + op.getIndices(), indexSubVec, maskSubVec, + passThruSubVec); + }; - rewriter.replaceOp(op, result); - return success(); + return unrollVectorOp(op, rewriter, unrollGatherFn); } }; diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp index 6e2fa35e1279a..841e1384e03b3 100644 --- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp @@ -392,3 +392,29 @@ vector::isValidMaskedInputVector(ArrayRef shape, } return success(); } + +LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter, + vector::UnrollVectorOpFn unrollFn) { + assert(op->getNumResults() == 1 && "expected single result"); + assert(isa(op->getResult(0).getType()) && "expected vector type"); + VectorType resultTy = cast(op->getResult(0).getType()); + if (resultTy.getRank() < 2) + return rewriter.notifyMatchFailure(op, "already 1-D"); + + // Unrolling doesn't take vscale into account. Pattern is disabled for + // vectors with leading scalable dim(s). + if (resultTy.getScalableDims().front()) + return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim"); + + Location loc = op->getLoc(); + Value result = ub::PoisonOp::create(rewriter, loc, resultTy); + VectorType subTy = VectorType::Builder(resultTy).dropDim(0); + + for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) { + Value subVector = unrollFn(rewriter, loc, subTy, i); + result = vector::InsertOp::create(rewriter, loc, subVector, result, i); + } + + rewriter.replaceOp(op, result); + return success(); +} diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 72810b5dddaa3..07d335117de01 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1737,3 +1737,40 @@ func.func @step() -> vector<4xindex> { %0 = vector.step : vector<4xindex> return %0 : vector<4xindex> } + + +// ----- + +//===----------------------------------------------------------------------===// +// vector.from_elements +//===----------------------------------------------------------------------===// + +// NOTE: We unroll multi-dimensional from_elements ops with pattern `UnrollFromElements` +// and then convert the 1-D from_elements ops to llvm. + +// CHECK-LABEL: func @from_elements_3d +// CHECK-SAME: %[[ARG_0:.*]]: f32, %[[ARG_1:.*]]: f32, %[[ARG_2:.*]]: f32, %[[ARG_3:.*]]: f32) +// CHECK: %[[UNDEF_RES:.*]] = ub.poison : vector<2x1x2xf32> +// CHECK: %[[UNDEF_RES_LLVM:.*]] = builtin.unrealized_conversion_cast %[[UNDEF_RES]] : vector<2x1x2xf32> to !llvm.array<2 x array<1 x vector<2xf32>>> +// CHECK: %[[UNDEF_VEC_RANK_2:.*]] = ub.poison : vector<1x2xf32> +// CHECK: %[[UNDEF_VEC_RANK_2_LLVM:.*]] = builtin.unrealized_conversion_cast %[[UNDEF_VEC_RANK_2]] : vector<1x2xf32> to !llvm.array<1 x vector<2xf32>> +// CHECK: %[[UNDEF_VEC0:.*]] = llvm.mlir.poison : vector<2xf32> +// CHECK: %[[C0_0:.*]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK: %[[VEC0_0:.*]] = llvm.insertelement %[[ARG_0]], %[[UNDEF_VEC0]][%[[C0_0]] : i64] : vector<2xf32> +// CHECK: %[[C1_0:.*]] = llvm.mlir.constant(1 : i64) : i64 +// CHECK: %[[VEC0_1:.*]] = llvm.insertelement %[[ARG_1]], %[[VEC0_0]][%[[C1_0]] : i64] : vector<2xf32> +// CHECK: %[[RES_RANK_2_0:.*]] = llvm.insertvalue %[[VEC0_1]], %[[UNDEF_VEC_RANK_2_LLVM]][0] : !llvm.array<1 x vector<2xf32>> +// CHECK: %[[RES_0:.*]] = llvm.insertvalue %[[RES_RANK_2_0]], %[[UNDEF_RES_LLVM]][0] : !llvm.array<2 x array<1 x vector<2xf32>>> +// CHECK: %[[UNDEF_VEC1:.*]] = llvm.mlir.poison : vector<2xf32> +// CHECK: %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK: %[[VEC1_0:.*]] = llvm.insertelement %[[ARG_2]], %[[UNDEF_VEC1]][%[[C0_1]] : i64] : vector<2xf32> +// CHECK: %[[C1_1:.*]] = llvm.mlir.constant(1 : i64) : i64 +// CHECK: %[[VEC1_1:.*]] = llvm.insertelement %[[ARG_3]], %[[VEC1_0]][%[[C1_1]] : i64] : vector<2xf32> +// CHECK: %[[RES_RANK_2_1:.*]] = llvm.insertvalue %[[VEC1_1]], %[[UNDEF_VEC_RANK_2_LLVM]][0] : !llvm.array<1 x vector<2xf32>> +// CHECK: %[[RES_1:.*]] = llvm.insertvalue %[[RES_RANK_2_1]], %[[RES_0]][1] : !llvm.array<2 x array<1 x vector<2xf32>>> +// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[RES_1]] : !llvm.array<2 x array<1 x vector<2xf32>>> to vector<2x1x2xf32> +// CHECK: return %[[CAST]] +func.func @from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x1x2xf32> { + %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32> + return %0 : vector<2x1x2xf32> +} diff --git a/mlir/test/Dialect/Vector/vector-from-elements-lowering.mlir b/mlir/test/Dialect/Vector/vector-from-elements-lowering.mlir new file mode 100644 index 0000000000000..8fac608ed5692 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-from-elements-lowering.mlir @@ -0,0 +1,45 @@ +// RUN: mlir-opt %s -test-unroll-vector-from-elements | FileCheck %s --check-prefix=CHECK-UNROLL + +//===----------------------------------------------------------------------===// +// Test UnrollFromElements. +//===----------------------------------------------------------------------===// + +// CHECK-UNROLL-LABEL: @unroll_from_elements_2d +// CHECK-UNROLL-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32) +// CHECK-UNROLL-NEXT: %[[UNDEF_RES:.*]] = ub.poison : vector<2x2xf32> +// CHECK-UNROLL-NEXT: %[[VEC_0:.*]] = vector.from_elements %[[ARG0]], %[[ARG1]] : vector<2xf32> +// CHECK-UNROLL-NEXT: %[[RES_0:.*]] = vector.insert %[[VEC_0]], %[[UNDEF_RES]] [0] : vector<2xf32> into vector<2x2xf32> +// CHECK-UNROLL-NEXT: %[[VEC_1:.*]] = vector.from_elements %[[ARG2]], %[[ARG3]] : vector<2xf32> +// CHECK-UNROLL-NEXT: %[[RES_1:.*]] = vector.insert %[[VEC_1]], %[[RES_0]] [1] : vector<2xf32> into vector<2x2xf32> +// CHECK-UNROLL-NEXT: return %[[RES_1]] : vector<2x2xf32> +func.func @unroll_from_elements_2d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x2xf32> { + %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x2xf32> + return %0 : vector<2x2xf32> +} + +// CHECK-UNROLL-LABEL: @unroll_from_elements_3d +// CHECK-UNROLL-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32) +// CHECK-UNROLL-NEXT: %[[UNDEF_RES:.*]] = ub.poison : vector<2x1x2xf32> +// CHECK-UNROLL-NEXT: %[[UNDEF_RANK_2:.*]] = ub.poison : vector<1x2xf32> +// CHECK-UNROLL-NEXT: %[[VEC_0:.*]] = vector.from_elements %[[ARG0]], %[[ARG1]] : vector<2xf32> +// CHECK-UNROLL-NEXT: %[[RANK_2_0:.*]] = vector.insert %[[VEC_0]], %[[UNDEF_RANK_2]] [0] : vector<2xf32> into vector<1x2xf32> +// CHECK-UNROLL-NEXT: %[[RES_0:.*]] = vector.insert %[[RANK_2_0]], %[[UNDEF_RES]] [0] : vector<1x2xf32> into vector<2x1x2xf32> +// CHECK-UNROLL-NEXT: %[[VEC_1:.*]] = vector.from_elements %[[ARG2]], %[[ARG3]] : vector<2xf32> +// CHECK-UNROLL-NEXT: %[[RANK_2_1:.*]] = vector.insert %[[VEC_1]], %[[UNDEF_RANK_2]] [0] : vector<2xf32> into vector<1x2xf32> +// CHECK-UNROLL-NEXT: %[[RES_1:.*]] = vector.insert %[[RANK_2_1]], %[[RES_0]] [1] : vector<1x2xf32> into vector<2x1x2xf32> +// CHECK-UNROLL-NEXT: return %[[RES_1]] : vector<2x1x2xf32> +func.func @unroll_from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x1x2xf32> { + %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32> + return %0 : vector<2x1x2xf32> +} + +// 1-D vector.from_elements should not be unrolled. + +// CHECK-UNROLL-LABEL: @negative_unroll_from_elements_1d +// CHECK-UNROLL-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32) +// CHECK-UNROLL-NEXT: %[[RES:.*]] = vector.from_elements %[[ARG0]], %[[ARG1]] : vector<2xf32> +// CHECK-UNROLL-NEXT: return %[[RES]] : vector<2xf32> +func.func @negative_unroll_from_elements_1d(%arg0: f32, %arg1: f32) -> vector<2xf32> { + %0 = vector.from_elements %arg0, %arg1 : vector<2xf32> + return %0 : vector<2xf32> +} diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir index 5be267c1be984..9c2a508671e06 100644 --- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir @@ -81,7 +81,7 @@ func.func @gather_memref_1d_i32_index(%base: memref, %v: vector<2xi32>, % // CHECK-SAME: %[[PASS:.*]]: vector<2x[3]xf32> // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x[3]xf32> +// CHECK: %[[INIT:.*]] = ub.poison : vector<2x[3]xf32> // CHECK: %[[IDXVEC0:.*]] = vector.extract %[[IDXVEC]][0] : vector<[3]xindex> from vector<2x[3]xindex> // CHECK: %[[MASK0:.*]] = vector.extract %[[MASK]][0] : vector<[3]xi1> from vector<2x[3]xi1> // CHECK: %[[PASS0:.*]] = vector.extract %[[PASS]][0] : vector<[3]xf32> from vector<2x[3]xf32> diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index f89c944b5c564..bb1598ee3efe5 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -786,6 +786,28 @@ struct TestVectorGatherLowering } }; +struct TestUnrollVectorFromElements + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnrollVectorFromElements) + + StringRef getArgument() const final { + return "test-unroll-vector-from-elements"; + } + StringRef getDescription() const final { + return "Test unrolling patterns for from_elements ops"; + } + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + populateVectorFromElementsLoweringPatterns(patterns); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); + } +}; + struct TestFoldArithExtensionIntoVectorContractPatterns : public PassWrapper> { @@ -1059,6 +1081,8 @@ void registerTestVectorLowerings() { PassRegistration(); + PassRegistration(); + PassRegistration(); PassRegistration(); diff --git a/mlir/test/python/dialects/transform_vector_ext.py b/mlir/test/python/dialects/transform_vector_ext.py index a51f2154d1f7d..5a648fe073315 100644 --- a/mlir/test/python/dialects/transform_vector_ext.py +++ b/mlir/test/python/dialects/transform_vector_ext.py @@ -46,6 +46,8 @@ def non_configurable_patterns(): vector.ApplyLowerOuterProductPatternsOp() # CHECK: transform.apply_patterns.vector.lower_gather vector.ApplyLowerGatherPatternsOp() + # CHECK: transform.apply_patterns.vector.unroll_from_elements + vector.ApplyUnrollFromElementsPatternsOp() # CHECK: transform.apply_patterns.vector.lower_scan vector.ApplyLowerScanPatternsOp() # CHECK: transform.apply_patterns.vector.lower_shape_cast