diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorLinearize.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorLinearize.h
new file mode 100644
index 0000000000000..6fc5bb13f9b07
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorLinearize.h
@@ -0,0 +1,354 @@
+//===- VectorLinearize.h - Vector linearization patterns --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORLINEARIZE_H
+#define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORLINEARIZE_H
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace vector {
+
+/// Initialize `typeConverter` with source and target materialization logic
+/// using shape_casts to/from 1D vectors.
+void initializeForVectorLinearize(TypeConverter &typeConverter);
+
+/// This enum controls the patterns used for linearization of insert,
+/// insert_strided_slice, extract, and extract_strided_slice operations.
+enum class InsertExtractLinearizePreference {
+
+  /// The lowerings are
+  ///    insert, insert_strided_slice   -> 1D shuffle
+  ///    extract, extract_strided_slice -> 1D shuffle
+  ///
+  /// Even 1D insert_strided_slice and extract_strided_slice are converted to
+  /// 1D shuffles. Insert and extract ops on scalar elements are not converted
+  /// to 1D shuffles.
+  Shuffle = 0,
+
+  /// The preferred lowerings are
+  ///    insert, insert_strided_slice   -> 1D insert_strided_slice
+  ///    extract, extract_strided_slice -> 1D extract_strided_slice
+  ///
+  /// When these lowerings are not possible because the slices are not
+  /// contiguous, 1D shuffles are used.
+  Strided
+};
+
+/// Initialize `conversionTarget` and `patterns` for linearization. Here,
+/// linearization means converting a single operation with 1+ vector
+/// operand/result of rank>1 into a new single operation whose vector operands
+/// and results are all of rank<=1.
+///
+/// This function initializes `conversionTarget` with the set of operations
+/// that are illegal and consequently must be converted to a linearized form.
+/// It also populates the set of patterns that can be run to convert illegal
+/// operations, and assigns their priorities/benefits. The patterns and
+/// legality rules depend on `preference`, which controls the benefit
+/// associated with the patterns based on whether 1D shuffles or 1D strided
+/// ops are preferred.
+///
+/// Note: a user can extend the set of legal operations, for example if
+/// certain rank>1 vectors are considered valid, by adding additional
+/// dynamically legal ops to `conversionTarget`.
+///
+/// Further note: the choice to use a dialect conversion design for
+/// linearization is to make it easy to reuse generic structural type
+/// conversions for linearizing scf/cf/func operations.
+void populateForFullVectorLinearize(
+    const TypeConverter &, ConversionTarget &conversionTarget,
+    RewritePatternSet &patterns,
+    InsertExtractLinearizePreference preference =
+        InsertExtractLinearizePreference::Shuffle);
+
+enum class LinearizePattern {
+
+  /// BEFORE
+  /// %1 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
+  ///
+  /// AFTER
+  /// %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf32>
+  /// %1 = vector.shape_cast %0 : vector<4xf32> to vector<2x2xf32>
+  LinearizeConstantLike = 0,
+
+  /// BEFORE
+  /// %2 = math.sin %arg0 : vector<2x2xf32>
+  ///
+  /// AFTER
+  /// %0 = vector.shape_cast %arg0 : vector<2x2xf32> to vector<4xf32>
+  /// %1 = math.sin %0 : vector<4xf32>
+  /// %2 = vector.shape_cast %1 : vector<4xf32> to vector<2x2xf32>
+  LinearizeVectorizable,
+
+  /// BEFORE
+  /// %2 = vector.bitcast %arg0 : vector<4x4xf32> to vector<4x8xf16>
+  ///
+  /// AFTER
+  /// %0 = vector.shape_cast %arg0 : vector<4x4xf32> to vector<16xf32>
+  /// %1 = vector.bitcast %0 : vector<16xf32> to vector<32xf16>
+  /// %2 = vector.shape_cast %1 : vector<32xf16> to vector<4x8xf16>
+  LinearizeVectorBitCast,
+
+  /// This pattern currently only supports 2D masks with a unit outer
+  /// dimension.
+  ///
+  /// BEFORE
+  /// %mask_2d = vector.create_mask %arg0, %arg1 : vector<1x4xi1>
+  ///
+  /// AFTER
+  /// [...]
+  /// %mask_1d = vector.create_mask %mul : vector<4xi1>
+  /// %mask_2d = vector.shape_cast %mask_1d : vector<4xi1> to vector<1x4xi1>
+  ///
+  /// where `%mul` is a function of `%arg0` and `%arg1`.
+  LinearizeVectorCreateMask,
+
+  /// BEFORE
+  /// %shuffle_3d = vector.shuffle %v1_3d, %v2_3d [ shuffle_indices ]
+  ///
+  /// AFTER
+  /// %v1_1d = vector.shape_cast %v1_3d : [...]
+  /// %v2_1d = vector.shape_cast %v2_3d : [...]
+  /// %shuffle_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
+  /// %shuffle_3d = vector.shape_cast %shuffle_1d : [...]
+  ///
+  /// where `shuffle_indices_1d` is computed by expanding `shuffle_indices`.
+  LinearizeVectorShuffle,
+
+  /// BEFORE
+  /// %1 = vector.splat %value : vector<4x4xf32>
+  ///
+  /// AFTER
+  /// %0 = vector.splat %value : vector<16xf32>
+  /// %1 = vector.shape_cast %0 : vector<16xf32> to vector<4x4xf32>
+  LinearizeVectorSplat,
+
+  /// Reduce the rank of a vector.extract_strided_slice to the lowest rank
+  /// possible. For extract_strided_slice ops that slice contiguous elements,
+  /// the reduced-rank op is 1D, otherwise it is higher dimensional.
+  ///
+  /// BEFORE
+  /// %2 = vector.extract_strided_slice %arg0 {
+  ///        offsets = [1, 0, 1, 0],
+  ///        sizes = [1, 2, 1, 2],
+  ///        strides = [1, 1, 1, 1]} : vector<2x2x2x2xi8> to vector<1x2x1x2xi8>
+  ///
+  /// AFTER
+  /// %0 = vector.shape_cast %arg0 : vector<2x2x2x2xi8> to vector<4x4xi8>
+  /// %1 = vector.extract_strided_slice %0 {
+  ///        offsets = [2, 2],
+  ///        sizes = [2, 2],
+  ///        strides = [1, 1]} : vector<4x4xi8> to vector<2x2xi8>
+  /// %2 = vector.shape_cast %1 : vector<2x2xi8> to vector<1x2x1x2xi8>
+  RankReduceExtractStridedSlice,
+
+  /// Similar to RankReduceExtractStridedSlice, but both operands have their
+  /// ranks reduced.
+  ///
+  /// BEFORE
+  /// %3 = vector.insert_strided_slice %arg1, %arg0 {[...]}
+  ///          vector<1x2x1x2xi8> into vector<2x2x2x2xi8>
+  ///
+  /// AFTER
+  /// %0 = vector.shape_cast %arg0 : vector<2x2x2x2xi8> to vector<4x4xi8>
+  /// %1 = vector.shape_cast %arg1 : vector<1x2x1x2xi8> to vector<2x2xi8>
+  /// %2 = vector.insert_strided_slice %1, %0 {[...]}
+  /// %3 = vector.shape_cast %2 : vector<4x4xi8> to vector<2x2x2x2xi8>
+  RankReduceInsertStridedSlice,
+
+  /// BEFORE
+  /// %extract = vector.extract %src [ position ]
+  ///
+  /// AFTER
+  /// %src_1d = vector.shape_cast %src : [...]
+  /// %out_1d = vector.shuffle %src_1d, %src_1d [ shuffle_indices ]
+  /// %out_nd = vector.shape_cast %out_1d : [...]
+  ///
+  /// `shuffle_indices` is computed from `position`.
+  VectorExtractToRankOneShuffle,
+
+  /// BEFORE
+  /// %out_nd = vector.extract_strided_slice %source_nd
+  ///         { offsets = [..], strides = [..], sizes = [..] }
+  ///
+  /// AFTER
+  /// %source_1d = vector.shape_cast %source_nd [...]
+  /// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
+  /// %out_nd = vector.shape_cast %out_1d [...]
+  ///
+  /// `shuffle_indices_1d` is computed using the offsets and sizes of the
+  /// original vector.extract_strided_slice operation.
+  VectorExtractStridedSliceToRankOneShuffle,
+
+  /// BEFORE
+  /// %1 = vector.extract %arg0[1, 2] : vector<2x1xi8> from vector<4x3x2x1xi8>
+  ///
+  /// AFTER
+  /// %0 = vector.shape_cast %arg0 : vector<4x3x2x1xi8> to vector<24xi8>
+  /// %1 = vector.extract_strided_slice %0 {offsets = [10], sizes = [2] [...]
+  /// %2 = vector.shape_cast %1 : vector<2xi8> to vector<2x1xi8>
+  VectorExtractToRankOneStrided,
+
+  /// BEFORE
+  /// %insert = vector.insert %src, %dst [ position ]
+  ///
+  /// AFTER
+  /// %src_1d = vector.shape_cast %src : [...]
+  /// %dst_1d = vector.shape_cast %dst : [...]
+  /// %out_1d = vector.shuffle %dst_1d, %src_1d [ shuffle_indices ]
+  /// %out_nd = vector.shape_cast %out_1d : [...]
+  ///
+  /// `shuffle_indices` is computed from `position`.
+  VectorInsertToRankOneShuffle,
+
+  /// This pattern converts a vector.insert_strided_slice operation into a
+  /// vector.shuffle operation that has rank-1 (linearized) operands and
+  /// result.
+  ///
+  /// BEFORE
+  /// %0 = vector.insert_strided_slice %to_store, %into
+  ///          {offsets = [1, 0, 0, 0], strides = [1, 1]}
+  ///          : vector<2x2xi8> into vector<2x1x3x2xi8>
+  /// AFTER
+  /// %to_store_1d
+  ///          = vector.shape_cast %to_store : vector<2x2xi8> to vector<4xi8>
+  /// %into_1d = vector.shape_cast %into : vector<2x1x3x2xi8> to vector<12xi8>
+  /// %out_1d  = vector.shuffle %into_1d, %to_store_1d [ shuffle_indices_1d ]
+  /// %out_nd  = vector.shape_cast %out_1d : vector<12xi8> to vector<2x1x3x2xi8>
+  ///
+  /// where shuffle_indices_1d in this case is
+  ///     [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11].
+  ///                        ^^^^^^^^^^^^^^
+  ///                          to_store_1d
+  VectorInsertStridedSliceToRankOneShuffle,
+
+  /// Similar to VectorExtractToRankOneStrided, but converts vector.insert to
+  /// a rank-1 vector.insert_strided_slice.
+  VectorInsertToRankOneStrided,
+
+  /// The number of patterns in this enum.
+  N
+};
+
+/// This class contains functions to control the set of linearization patterns
+/// to include for the conversion, and their priorities.
+struct VectorLinearizePatterns {
+
+public:
+  /// By default all patterns are enabled and have benefit 1.
+  VectorLinearizePatterns() {
+    enabled.fill(true);
+    benefits.fill(PatternBenefit(1));
+  }
+
+  /// Add the patterns enabled for the conversion to `patterns`.
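+  ///
+  /// A typical use, as a sketch only (`typeConverter` and `patterns` stand
+  /// for whatever converter and pattern set the caller already has):
+  ///
+  ///   VectorLinearizePatterns()
+  ///       .enableAll(false)
+  ///       .enable(LinearizePattern::LinearizeVectorizable)
+  ///       .setBenefit(LinearizePattern::LinearizeVectorizable, 2)
+  ///       .addToPatternSet(typeConverter, patterns);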
+  void addToPatternSet(const TypeConverter &,
+                       RewritePatternSet &patterns) const;
+
+  VectorLinearizePatterns &enable(LinearizePattern id, bool e = true) {
+    enabled[static_cast<unsigned>(id)] = e;
+    return *this;
+  }
+
+  VectorLinearizePatterns &enableAll(bool e = true) {
+    enabled.fill(e);
+    return *this;
+  }
+
+  bool isEnabled(LinearizePattern id) const {
+    return enabled[static_cast<unsigned>(id)];
+  }
+
+  PatternBenefit getBenefit(LinearizePattern id) const {
+    return benefits[static_cast<unsigned>(id)];
+  }
+
+  VectorLinearizePatterns &setBenefit(LinearizePattern id,
+                                      PatternBenefit benefit) {
+    getBenefitRef(id) = benefit;
+    return *this;
+  }
+
+  VectorLinearizePatterns &incrementBenefit(LinearizePattern id,
+                                            unsigned inc = 1) {
+    getBenefitRef(id) = getBenefit(id).getBenefit() + inc;
+    return *this;
+  }
+
+private:
+  std::array<bool, static_cast<unsigned>(LinearizePattern::N)> enabled;
+  std::array<PatternBenefit, static_cast<unsigned>(LinearizePattern::N)>
+      benefits;
+
+  PatternBenefit &getBenefitRef(LinearizePattern id) {
+    unsigned idInt = static_cast<unsigned>(id);
+    assert(idInt < static_cast<unsigned>(LinearizePattern::N) &&
+           "invalid linearization pattern id");
+    return benefits[idInt];
+  }
+};
+
+/// Consider inserting a vector of shape `small` into a vector of shape
+/// `large`, at position `offsets`: this function enumerates all the indices
+/// in `large` that are written to. The enumeration is with row-major
+/// ordering.
+///
+/// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2
+/// positions written to are (1,3) and (1,4), which have linearized indices 8
+/// and 9. So [8,9] is returned.
+///
+/// The length of the returned vector is equal to the number of elements in
+/// the shape `small` (i.e. the product of dimensions of `small`).
+SmallVector<int64_t> getStridedSliceInsertionIndices(ArrayRef<int64_t> small,
+                                                     ArrayRef<int64_t> large,
+                                                     ArrayRef<int64_t> offsets);
+
+/// Return the strided slice with the lowest rank that is equivalent to the
+/// strided slice of `small` from `large`, starting at `offsets`. The result
+/// is a tuple of three vectors:
+///
+/// 0) The shape of the new small vector.
+/// 1) The shape of the new large vector.
+/// 2) The offsets of the new large vector.
+///
+/// Example 1 (contiguous slices can always be represented in 1-D).
+///
+/// Input:
+///    small  = (1, 3, 4)
+///    large  = (3, 3, 4)
+///    offset = (2, 0, 0)
+///
+/// Output:
+///    small  = (12)
+///    large  = (36)
+///    offset = (24)
+///
+/// Example 2 (a non-contiguous slice)
+///
+/// Input:
+///    small  = (2, 2, 1, 2)
+///    large  = (2, 2, 2, 2, 2)
+///    offset = (1, 1, 0, 1)
+///
+/// Output:
+///    small  = (4, 2)
+///    large  = (8, 4)
+///    offset = (6, 2)
+std::array<SmallVector<int64_t>, 3>
+getCollapsedStridedSliceShape(ArrayRef<int64_t> small, ArrayRef<int64_t> large,
+                              ArrayRef<int64_t> offsets);
+
+std::optional<std::array<SmallVector<int64_t>, 3>>
+getCollapsedExtractStridedSliceShape(vector::ExtractStridedSliceOp extractOp);
+
+std::optional<std::array<SmallVector<int64_t>, 3>>
+getCollapsedInsertStridedSliceShape(vector::InsertStridedSliceOp insertOp);
+
+} // namespace vector
+} // namespace mlir
+
+#endif
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 34a94e6ea7051..6954cb7172129 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -406,39 +406,6 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
 void populateVectorTransposeNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
-/// Initialize `typeConverter` and `conversionTarget` for vector linearization.
-///
-/// Definition: here 'linearization' means converting a single operation with
-/// 1+ vector operand/result of rank>1, into a new single operation whose
-/// vector operands and results are all of rank<=1.
-///
-/// This function registers (1) which operations are legal, and hence should not
-/// be linearized, (2) what the converted types are (rank-1 vectors) and how to
-/// materialze the conversion (with shape_cast)
-///
-/// Note: the set of legal operations can be extended by a user if for example
-/// certain rank>1 vectors are considered valid, by adding additional
-/// dynamically legal ops to `conversionTarget`.
-///
-/// Further note: the choice to use a dialect conversion design for
-/// linearization is to make it easy to reuse generic structural type
-/// conversions for linearizing scf/cf/func operations
-void populateForVectorLinearize(TypeConverter &typeConverter,
-                                ConversionTarget &conversionTarget);
-
-/// Populates `patterns` for ND vector (N >= 2) linearization. This currently
-/// contains patterns for converting ConstantLike, Vectorizable, and
-/// vector::BitCast ops.
-void populateVectorLinearizeBasePatterns(const TypeConverter &,
-                                         const ConversionTarget &,
-                                         RewritePatternSet &patterns);
-
-/// Populates `patterns` for linearizing ND (N >= 2) vector operations
-/// to 1D vector shuffle operations.
-void populateVectorLinearizeShuffleLikeOpsPatterns(const TypeConverter &,
-                                                   const ConversionTarget &,
-                                                   RewritePatternSet &patterns);
-
 } // namespace vector
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index fcfb401fd9867..359b2ba091967 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5726,13 +5726,21 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
   VectorType resultType = getType();
 
-  // No-op shape cast.
-  if (getSource().getType() == resultType)
-    return getSource();
-
-  // shape_cast(shape_cast(x)) -> shape_cast(x)
-  if (auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
-    setOperand(precedingShapeCast.getSource());
+  // y = shape_cast(shape_cast(shape_cast(x)))
+  //       -> shape_cast(x)   # if x and y have different types
+  //       -> x               # if x and y have the same type
+  ShapeCastOp parent = *this;
+  while (auto precedingShapeCast =
+             parent.getSource().getDefiningOp<ShapeCastOp>()) {
+    parent = precedingShapeCast;
+  }
+
+  if (parent.getSource().getType() == resultType)
+    return parent.getSource();
+
+  if (parent != *this) {
+    setOperand(parent.getSource());
     return getResult();
   }
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 678a88627ca82..fae452d8e5dc9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -10,9 +10,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Vector/Transforms/VectorLinearize.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Operation.h"
@@ -45,6 +46,408 @@ linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter,
   return rewriter.notifyMatchFailure(loc, "unsupported attr type");
 }
 
+/// Convert an array of attributes into a vector of integers.
+static FailureOr<SmallVector<int64_t>> intsFromArrayAttr(ArrayAttr attrs) {
+  if (!attrs)
+    return failure();
+  SmallVector<int64_t> ints;
+  ints.reserve(attrs.size());
+  for (auto attr : attrs) {
+    if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+      ints.push_back(intAttr.getInt());
+    } else {
+      return failure();
+    }
+  }
+  return ints;
+}
+
+/// Convert OpFoldResults into a vector of integers, failing when an
+/// OpFoldResult is not an Attribute (unless the dimension in `shape` is 1, in
+/// which case the offset is necessarily 0). Ensure that the returned vector
+/// is of the same rank as `shape` by appending zeros.
+static FailureOr<SmallVector<int64_t>>
+getIntegerOffsetsFromFoldResults(ArrayRef<OpFoldResult> offsetFoldResults,
+                                 ArrayRef<int64_t> shape) {
+  assert(shape.size() >= offsetFoldResults.size() &&
+         "offsets assumed not to be of higher rank than shape");
+  unsigned deltaRank = shape.size() - offsetFoldResults.size();
+  SmallVector<int64_t> offsets;
+  offsets.reserve(offsetFoldResults.size());
+  for (auto [offsetFoldResult, dimSize] :
+       llvm::zip(offsetFoldResults, shape.drop_back(deltaRank))) {
+    if (dimSize == 1) {
+      offsets.push_back(0);
+    } else if (auto offsetAttr = dyn_cast<Attribute>(offsetFoldResult)) {
+      offsets.push_back(cast<IntegerAttr>(offsetAttr).getInt());
+    } else {
+      return failure();
+    }
+  }
+  offsets.resize(shape.size(), 0);
+  return offsets;
+}
+
+/// If `ndIndex` is the index in the n-dimensional array of shape `shape`, get
+/// the corresponding index into the flattened array.
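+///
+/// For example, ndIndex = (1, 2) with shape = (3, 4) gives 1 * 4 + 2 = 6
+/// (row-major ordering).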
+static int64_t getIndexInFlattened(ArrayRef<int64_t> ndIndex,
+                                   ArrayRef<int64_t> shape) {
+  assert(ndIndex.size() == shape.size() &&
+         "ndIndex and shape assumed to have the same size");
+  int64_t index = 0;
+  int64_t stride = 1;
+  for (int i = shape.size() - 1; i >= 0; --i) {
+    index += ndIndex[i] * stride;
+    stride *= shape[i];
+  }
+  return index;
+}
+
+/// Return true if `op` is an insert, extract, insert_strided_slice, or
+/// extract_strided_slice operation that operates on scalable vectors.
+/// Otherwise return false.
+static bool isScalableExtractOrInsertOrStrided(Operation *op) {
+  return TypeSwitch<Operation *, bool>(op)
+      .Case([&](vector::ExtractStridedSliceOp extractOp) {
+        return extractOp.getType().isScalable();
+      })
+      .Case([&](vector::InsertStridedSliceOp insertOp) {
+        return insertOp.getType().isScalable();
+      })
+      .Case([&](vector::InsertOp insertOp) {
+        return insertOp.getType().isScalable();
+      })
+      .Case([&](vector::ExtractOp extractOp) {
+        return extractOp.getSourceVectorType().isScalable();
+      })
+      .Default([&](auto) { return false; });
+}
+
+SmallVector<int64_t>
+vector::getStridedSliceInsertionIndices(ArrayRef<int64_t> small,
+                                        ArrayRef<int64_t> large,
+                                        ArrayRef<int64_t> offsets) {
+
+  // Example of the alignment between `large`, `small` and `offsets`:
+  //    large   = 4, 5, 6, 7, 8
+  //    small   =       1, 6, 7, 8
+  //    offsets = 2, 3, 0
+  //
+  // `offsets` has implicit trailing 0s, `small` has implicit leading 1s.
+  assert((large.size() >= small.size()) &&
+         "rank of 'large' cannot be lower than rank of 'small'");
+  assert((large.size() >= offsets.size()) &&
+         "rank of 'large' cannot be lower than the number of offsets");
+  unsigned delta = large.size() - small.size();
+  unsigned nOffsets = offsets.size();
+  auto getSmall = [&](int64_t i) -> int64_t {
+    return i >= delta ? small[i - delta] : 1;
+  };
+  auto getOffset = [&](int64_t i) -> int64_t {
+    return i < nOffsets ? offsets[i] : 0;
+  };
+
+  // Using 2 vectors of indices, at each iteration populate the updated set of
+  // indices based on the old set of indices, and the size of the small vector
+  // in the current iteration.
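+  //
+  // For example (the example from the function's header documentation):
+  // inserting shape (1, 2) into shape (4, 5) at offsets (1, 3), the innermost
+  // iteration produces {3, 4}, and the outer iteration then adds 1 * 5,
+  // giving the final indices {8, 9}.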
+  SmallVector<int64_t> indices{0};
+  int64_t stride = 1;
+  for (int i = large.size() - 1; i >= 0; --i) {
+    int64_t currentSize = indices.size();
+    int64_t smallSize = getSmall(i);
+    int64_t nextSize = currentSize * smallSize;
+    SmallVector<int64_t> nextIndices(nextSize);
+    int64_t *base = nextIndices.begin();
+    int64_t offset = getOffset(i) * stride;
+    for (int j = 0; j < smallSize; ++j) {
+      for (int k = 0; k < currentSize; ++k) {
+        base[k] = indices[k] + offset;
+      }
+      offset += stride;
+      base += currentSize;
+    }
+    stride *= large[i];
+    indices = std::move(nextIndices);
+  }
+  return indices;
+}
+
+void vector::initializeForVectorLinearize(TypeConverter &typeConverter) {
+
+  auto convertType = [](Type type) -> std::optional<Type> {
+    VectorType vectorType = dyn_cast<VectorType>(type);
+
+    if (!vectorType || !vector::isLinearizableVector(vectorType))
+      return type;
+
+    VectorType linearizedType =
+        VectorType::get(vectorType.getNumElements(),
+                        vectorType.getElementType(), vectorType.isScalable());
+
+    return linearizedType;
+  };
+  typeConverter.addConversion(convertType);
+
+  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
+                            Location loc) -> Value {
+    if (inputs.size() != 1) {
+      return nullptr;
+    }
+    Value value = inputs.front();
+    if (!isa<VectorType>(type) || !isa<VectorType>(value.getType())) {
+      return nullptr;
+    }
+    return builder.create<vector::ShapeCastOp>(loc, type, value);
+  };
+  typeConverter.addSourceMaterialization(materializeCast);
+  typeConverter.addTargetMaterialization(materializeCast);
+}
+
+void vector::populateForFullVectorLinearize(
+    const TypeConverter &typeConverter, ConversionTarget &target,
+    RewritePatternSet &patterns, InsertExtractLinearizePreference preference) {
+
+  target.markUnknownOpDynamicallyLegal(
+      [=](Operation *op) -> std::optional<bool> {
+        // Only ops that are in the vector dialect, are ConstantLike, or
+        // are Vectorizable might be linearized currently.
+        StringLiteral vectorDialect =
+            vector::VectorDialect::getDialectNamespace();
+        StringRef opDialect = op->getDialect()->getNamespace();
+        bool supported = (opDialect == vectorDialect) ||
+                         op->hasTrait<OpTrait::ConstantLike>() ||
+                         op->hasTrait<OpTrait::Vectorizable>();
+        if (!supported)
+          return true;
+
+        // As type legalization is done with vector.shape_cast, shape_cast
+        // itself cannot be linearized (doing so would create new shape_casts
+        // to linearize ad infinitum).
+        if (isa<vector::ShapeCastOp>(op))
+          return true;
+
+        // The operations extract_strided_slice, extract,
+        // insert_strided_slice, and insert are linearized to rank-1
+        // operations that do not fully support scalable vectors, so it is
+        // not generally possible to linearize these ops if they operate on
+        // scalable vectors.
+        if (isScalableExtractOrInsertOrStrided(op))
+          return true;
+
+        // This will return true if, for all operand and result types `t`,
+        // convertType(t) = t. This is true if there are no rank>=2 vectors.
+        return typeConverter.isLegal(op);
+      });
+
+  VectorLinearizePatterns linearizePatterns;
+
+  if (preference == InsertExtractLinearizePreference::Shuffle) {
+    // Mark extract_strided_slice, insert_strided_slice, extract with source
+    // rank > 1, and insert with result rank > 1 as illegal, as they must be
+    // converted to shuffle or rank-1 extract/insert.
+    //
+    // Note that the order of the calls to `markUnknownOpDynamicallyLegal`
+    // is important: the legality rule added here takes precedence over the
+    // generic one preceding it which marked these ops as legal.
+    target.markUnknownOpDynamicallyLegal(
+        [](Operation *op) -> std::optional<bool> {
+          bool isStrided =
+              isa<vector::ExtractStridedSliceOp, vector::InsertStridedSliceOp>(
+                  op);
+
+          bool isHighRankExtractOrInsert = [&]() {
+            if (auto extractOp = dyn_cast<vector::ExtractOp>(op)) {
+              return extractOp.getSourceVectorType().getRank() > 1;
+            }
+            if (auto insertOp = dyn_cast<vector::InsertOp>(op)) {
+              return insertOp.getType().getRank() > 1;
+            }
+            return false;
+          }();
+
+          bool isScalable = isScalableExtractOrInsertOrStrided(op);
+
+          if ((isStrided || isHighRankExtractOrInsert) && !isScalable) {
+            return false;
+          }
+          return std::nullopt;
+        });
+
+    // Ensure that the benefit of patterns targeting shuffle is higher than
+    // the benefit of patterns targeting rank-1 strided slice operations.
+    // This ensures that patterns for converting to rank-1 shuffle are run
+    // first.
+    linearizePatterns
+        .incrementBenefit(
+            LinearizePattern::VectorExtractStridedSliceToRankOneShuffle)
+        .incrementBenefit(
+            LinearizePattern::VectorInsertStridedSliceToRankOneShuffle)
+        .incrementBenefit(LinearizePattern::VectorExtractToRankOneShuffle)
+        .incrementBenefit(LinearizePattern::VectorInsertToRankOneShuffle);
+
+  } else if (preference == InsertExtractLinearizePreference::Strided) {
+    linearizePatterns
+        .incrementBenefit(LinearizePattern::RankReduceInsertStridedSlice)
+        .incrementBenefit(LinearizePattern::RankReduceExtractStridedSlice)
+        .incrementBenefit(LinearizePattern::VectorInsertToRankOneStrided)
+        .incrementBenefit(LinearizePattern::VectorExtractToRankOneStrided);
+  } else {
+    assert(false && "unsupported InsertExtractLinearizePreference");
+  }
+  linearizePatterns.addToPatternSet(typeConverter, patterns);
+}
+
+/// Get the lowest rank shapes and offsets which represent the same strided
+/// slice as the strided slice described by `small`, `large`, and `offsets`.
+///
+/// Example
+///
+/// %0 = vector.extract_strided_slice %1
+///    {offsets = [0, 0, 0], sizes = [2, 2, 2], strides = [1, 1, 1]} :
+///     vector<4x2x4xf32> to vector<2x2x2xf32>
+///
+/// is equivalent to
+///
+/// [...rank reducing shape casts...]
+/// %0 = vector.extract_strided_slice %1
+///    {offsets = [0, 0], sizes = [4, 2], strides = [1, 1]} :
+///     vector<8x4xf32> to vector<4x2xf32>
+/// [...rank increasing shape cast...]
+///
+/// So the output for
+///    (small, large, offsets = [2, 2, 2], [4, 2, 4], [0, 0, 0]) is
+///    (small, large, offsets = [4, 2], [8, 4], [0, 0])
+std::array<SmallVector<int64_t>, 3>
+vector::getCollapsedStridedSliceShape(ArrayRef<int64_t> small,
+                                      ArrayRef<int64_t> large,
+                                      ArrayRef<int64_t> offsets) {
+
+  // The total number of elements in the small (large, respectively) vector.
+  int64_t tSmall = std::accumulate(small.begin(), small.end(), int64_t(1),
+                                   std::multiplies<int64_t>());
+  int64_t tLarge = std::accumulate(large.begin(), large.end(), int64_t(1),
+                                   std::multiplies<int64_t>());
+  assert(tLarge >= tSmall &&
+         "total number of elements in 'small' cannot exceed that of 'large'");
+  assert(large.size() >= small.size() &&
+         "rank of 'small' cannot exceed rank of 'large'");
+  assert(offsets.size() <= large.size() &&
+         "number of offsets cannot exceed rank of 'large'");
+
+  int64_t nOffsets = offsets.size();
+  auto getOffset = [&](int64_t i) -> int64_t {
+    return i < nOffsets ? offsets[i] : 0;
+  };
+
+  unsigned delta = large.size() - small.size();
+
+  // The cumulative (product of dimensions) number of elements from the back
+  // currently visited in the small (large, respectively) vector.
+  int64_t nSmall = 1;
+  int64_t nLarge = 1;
+
+  // The cumulative number (product of dimensions) of elements from the back
+  // currently visited within the current collapse group in the small (large,
+  // respectively) vector.
+  int64_t cSmall = 1;
+  int64_t cLarge = 1;
+
+  SmallVector<int64_t> newSmall, newLarge, newOffsets;
+  if (large.empty())
+    return {newSmall, newLarge, newOffsets};
+
+  // The offset assigned to the current collapse group.
+  int64_t cOff = 0;
+
+  unsigned index = large.size() - 1;
+  while (nLarge < tLarge) {
+    assert(cSmall <= nSmall && nSmall <= tSmall && //
+           cLarge <= nLarge && nLarge <= tLarge &&
+           "confusion in element accumulation");
+    cOff += getOffset(index) * cLarge;
+    if (nSmall < tSmall) {
+      cSmall *= small[index - delta];
+      nSmall *= small[index - delta];
+    }
+    cLarge *= large[index];
+    nLarge *= large[index];
+    if ((nSmall < tSmall) && (large[index] != small[index - delta])) {
+      newSmall.push_back(cSmall);
+      newLarge.push_back(cLarge);
+      newOffsets.push_back(cOff);
+      cSmall = 1;
+      cLarge = 1;
+      cOff = 0;
+    }
+    --index;
+  }
+  newSmall.push_back(cSmall);
+  newLarge.push_back(cLarge);
+  newOffsets.push_back(cOff);
+  std::reverse(newSmall.begin(), newSmall.end());
+  std::reverse(newLarge.begin(), newLarge.end());
+  std::reverse(newOffsets.begin(), newOffsets.end());
+  return {newSmall, newLarge, newOffsets};
+}
+
+// Returns {small, large, offsets}.
+std::optional<std::array<SmallVector<int64_t>, 3>>
+vector::getCollapsedExtractStridedSliceShape(
+    vector::ExtractStridedSliceOp extractOp) {
+
+  if (extractOp.hasNonUnitStrides())
+    return std::nullopt;
+
+  ArrayRef<int64_t> outShape = extractOp.getType().getShape();
+  ArrayRef<int64_t> inShape = extractOp.getSourceVectorType().getShape();
+
+  auto maybeIntOffsets = intsFromArrayAttr(extractOp.getOffsets());
+  if (failed(maybeIntOffsets))
+    return std::nullopt;
+
+  SmallVector<int64_t> offsets = std::move(maybeIntOffsets.value());
+  const auto &[collapsedOutShape, collapsedInShape, collapsedOffsets] =
+      getCollapsedStridedSliceShape(outShape, inShape, offsets);
+
+  bool unchanged = (collapsedInShape.size() == inShape.size()) &&
+                   (collapsedOutShape.size() == outShape.size());
+
+  if (unchanged)
+    return std::nullopt;
+
+  return std::array<SmallVector<int64_t>, 3>{
+      collapsedOutShape, collapsedInShape, collapsedOffsets};
+}
+
+// Returns {small, large, offsets}.
+std::optional<std::array<SmallVector<int64_t>, 3>>
+vector::getCollapsedInsertStridedSliceShape(
+    vector::InsertStridedSliceOp insertOp) {
+
+  if (insertOp.hasNonUnitStrides())
+    return std::nullopt;
+
+  ArrayRef<int64_t> outShape = insertOp.getType().getShape();
+  ArrayRef<int64_t> inShape = insertOp.getSourceVectorType().getShape();
+
+  auto maybeIntOffsets = intsFromArrayAttr(insertOp.getOffsets());
+  if (failed(maybeIntOffsets))
+    return std::nullopt;
+
+  SmallVector<int64_t> offsets = std::move(maybeIntOffsets.value());
+  const auto &[collapsedInShape, collapsedOutShape, collapsedOffsets] =
+      getCollapsedStridedSliceShape(inShape, outShape, offsets);
+
+  bool unchanged = (collapsedInShape.size() == inShape.size()) &&
+                   (collapsedOutShape.size() == outShape.size());
+
+  if (unchanged)
+    return std::nullopt;
+
+  return std::array<SmallVector<int64_t>, 3>{
+      collapsedInShape, collapsedOutShape, collapsedOffsets};
+}
+
 namespace {
 
 struct LinearizeConstantLike final
@@ -52,7 +455,7 @@ struct LinearizeConstantLike final
   using OpTraitConversionPattern::OpTraitConversionPattern;
 
   LinearizeConstantLike(const TypeConverter &typeConverter,
-                        MLIRContext *context, PatternBenefit benefit = 1)
+                        MLIRContext *context, PatternBenefit benefit)
      : OpTraitConversionPattern(typeConverter, context, benefit) {}
 
   LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -94,7 +497,7 @@ struct LinearizeVectorizable final
 
 public:
   LinearizeVectorizable(const TypeConverter &typeConverter,
-                        MLIRContext *context, PatternBenefit benefit = 1)
+                        MLIRContext *context, PatternBenefit benefit)
      : OpTraitConversionPattern(typeConverter, context, benefit) {}
 
   LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -109,90 +512,6 @@ struct LinearizeVectorizable final
   }
 };
 
-template <typename TOp>
-static bool stridesAllOne(TOp op) {
-  static_assert(
-      std::is_same_v<TOp, vector::ExtractStridedSliceOp> ||
-          std::is_same_v<TOp, vector::InsertStridedSliceOp>,
-      "expected vector.extract_strided_slice or vector.insert_strided_slice");
-  ArrayAttr strides = op.getStrides();
-  return llvm::all_of(strides, isOneInteger);
-}
-
-/// Convert an array of attributes into a vector of integers, if possible.
-static FailureOr<SmallVector<int64_t>> intsFromArrayAttr(ArrayAttr attrs) {
-  if (!attrs)
-    return failure();
-  SmallVector<int64_t> ints;
-  ints.reserve(attrs.size());
-  for (auto attr : attrs) {
-    if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
-      ints.push_back(intAttr.getInt());
-    } else {
-      return failure();
-    }
-  }
-  return ints;
-}
-
-/// Consider inserting a vector of shape `small` into a vector of shape `large`,
-/// at position `offsets`: this function enumeratates all the indices in `large`
-/// that are written to. The enumeration is with row-major ordering.
-///
-/// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2
-/// positions written to are (1,3) and (1,4), which have linearized indices 8
-/// and 9. So [8,9] is returned.
-///
-/// The length of the returned vector is equal to the number of elements in
-/// the shape `small` (i.e. the product of dimensions of `small`).
-SmallVector<int64_t> static getStridedSliceInsertionIndices(
-    ArrayRef<int64_t> small, ArrayRef<int64_t> large,
-    ArrayRef<int64_t> offsets) {
-
-  // Example of alignment between, `large`, `small` and `offsets`:
-  // large  = 4, 5, 6, 7, 8
-  // small  =       1, 6, 7, 8
-  // offsets = 2, 3, 0
-  //
-  // `offsets` has implicit trailing 0s, `small` has implicit leading 1s.
-  assert((large.size() >= small.size()) &&
-         "rank of 'large' cannot be lower than rank of 'small'");
-  assert((large.size() >= offsets.size()) &&
-         "rank of 'large' cannot be lower than the number of offsets");
-  unsigned delta = large.size() - small.size();
-  unsigned nOffsets = offsets.size();
-  auto getSmall = [&](int64_t i) -> int64_t {
-    return i >= delta ? small[i - delta] : 1;
-  };
-  auto getOffset = [&](int64_t i) -> int64_t {
-    return i < nOffsets ? offsets[i] : 0;
-  };
-
-  // Using 2 vectors of indices, at each iteration populate the updated set of
-  // indices based on the old set of indices, and the size of the small vector
-  // in the current iteration.
-  SmallVector<int64_t> indices{0};
-  int64_t stride = 1;
-  for (int i = large.size() - 1; i >= 0; --i) {
-    int64_t currentSize = indices.size();
-    int64_t smallSize = getSmall(i);
-    int64_t nextSize = currentSize * smallSize;
-    SmallVector<int64_t> nextIndices(nextSize);
-    int64_t *base = nextIndices.begin();
-    int64_t offset = getOffset(i) * stride;
-    for (int j = 0; j < smallSize; ++j) {
-      for (int k = 0; k < currentSize; ++k) {
-        base[k] = indices[k] + offset;
-      }
-      offset += stride;
-      base += currentSize;
-    }
-    stride *= large[i];
-    indices = std::move(nextIndices);
-  }
-  return indices;
-}
-
 /// This pattern converts a vector.extract_strided_slice operation into a
 /// vector.shuffle operation that has a rank-1 (linearized) operand and result.
 ///
@@ -212,12 +531,12 @@ SmallVector<int64_t> static getStridedSliceInsertionIndices(
 ///
 /// `shuffle_indices_1d` is computed using the offsets and sizes of the original
 /// vector.extract_strided_slice operation.
-struct LinearizeVectorExtractStridedSlice final
-    : public mlir::OpConversionPattern<vector::ExtractStridedSliceOp> {
+struct VectorExtractStridedSliceToRankOneShuffle final
+    : public OpConversionPattern<vector::ExtractStridedSliceOp> {
   using OpConversionPattern::OpConversionPattern;
-  LinearizeVectorExtractStridedSlice(const TypeConverter &typeConverter,
-                                     MLIRContext *context,
-                                     PatternBenefit benefit = 1)
+  VectorExtractStridedSliceToRankOneShuffle(const TypeConverter &typeConverter,
+                                            MLIRContext *context,
+                                            PatternBenefit benefit)
      : OpConversionPattern(typeConverter, context, benefit) {}
 
   LogicalResult
@@ -231,7 +550,7 @@ struct LinearizeVectorExtractStridedSlice final
 
     // Expect a legalization failure if the strides are not all 1 (if ever the
     // verifier for extract_strided_slice allows non-1 strides).
-    if (!stridesAllOne(extractStridedSliceOp)) {
+    if (extractStridedSliceOp.hasNonUnitStrides()) {
      return rewriter.notifyMatchFailure(
           extractStridedSliceOp,
           "extract_strided_slice with strides != 1 not supported");
@@ -249,7 +568,7 @@ struct LinearizeVectorExtractStridedSlice final
 
     ArrayRef<int64_t> outputShape = extractStridedSliceOp.getType().getShape();
 
-    SmallVector<int64_t> indices = getStridedSliceInsertionIndices(
+    SmallVector<int64_t> indices = vector::getStridedSliceInsertionIndices(
         outputShape, inputShape, offsets.value());
 
     Value srcVector = adaptor.getVector();
@@ -259,36 +578,24 @@ struct LinearizeVectorExtractStridedSlice final
   }
 };
 
-/// This pattern converts a vector.insert_strided_slice operation into a
-/// vector.shuffle operation that has rank-1 (linearized) operands and result.
-///
-/// For example, the following:
-/// ```
-/// %0 = vector.insert_strided_slice %to_store, %into
-///   {offsets = [1, 0, 0, 0], strides = [1, 1]}
-///   : vector<2x2xi8> into vector<2x1x3x2xi8>
-/// ```
-///
-/// is converted to
-/// ```
-/// %to_store_1d
-///      = vector.shape_cast %to_store : vector<2x2xi8> to vector<4xi8>
-/// %into_1d = vector.shape_cast %into : vector<2x1x3x2xi8> to vector<12xi8>
-/// %out_1d = vector.shuffle %into_1d, %to_store_1d [ shuffle_indices_1d ]
-/// %out_nd = vector.shape_cast %out_1d : vector<12xi8> to vector<2x1x3x2xi8>
-/// ```
-///
-/// where shuffle_indices_1d in this case is
-///     [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11].
-///                        ^^^^^^^^^^^^^^
-///                         to_store_1d
-///
-struct LinearizeVectorInsertStridedSlice final
-    : public mlir::OpConversionPattern<vector::InsertStridedSliceOp> {
+static Value asRankOne(ConversionPatternRewriter &rewriter, Value v) {
+  auto vType = dyn_cast<VectorType>(v.getType());
+  assert(vType && "expected vector type");
+  assert(vType.getRank() <= 1 && "expected rank-0 or rank-1 type");
+  if (vType.getRank() == 1)
+    return v;
+  // Convert rank-0 vector to rank-1 vector.
+  v = rewriter.create<vector::ShapeCastOp>(
+      v.getLoc(), VectorType::get({1}, vType.getElementType()), v);
+  return v;
+}
+
+struct VectorInsertStridedSliceToRankOneShuffle final
+    : public OpConversionPattern<vector::InsertStridedSliceOp> {
   using OpConversionPattern::OpConversionPattern;
-  LinearizeVectorInsertStridedSlice(const TypeConverter &typeConverter,
-                                    MLIRContext *context,
-                                    PatternBenefit benefit = 1)
+  VectorInsertStridedSliceToRankOneShuffle(const TypeConverter &typeConverter,
+                                           MLIRContext *context,
+                                           PatternBenefit benefit)
       : OpConversionPattern(typeConverter, context, benefit) {}
 
   LogicalResult
@@ -298,7 +605,7 @@ struct LinearizeVectorInsertStridedSlice final
 
     // Expect a legalization failure if the strides are not all 1 (if ever the
    // verifier for insert_strided_slice allows non-1 strides).
-    if (!stridesAllOne(insertStridedSliceOp)) {
+    if (insertStridedSliceOp.hasNonUnitStrides()) {
      return rewriter.notifyMatchFailure(
           insertStridedSliceOp,
           "insert_strided_slice with strides != 1 not supported");
@@ -317,7 +624,7 @@ struct LinearizeVectorInsertStridedSlice final
      return rewriter.notifyMatchFailure(insertStridedSliceOp,
                                          "failed to get integer offsets");
     }
-    SmallVector<int64_t> sliceIndices = getStridedSliceInsertionIndices(
+    SmallVector<int64_t> sliceIndices = vector::getStridedSliceInsertionIndices(
         inputShape, outputShape, offsets.value());
 
     SmallVector<int64_t> indices(nOutputElements);
@@ -326,7 +633,7 @@ struct LinearizeVectorInsertStridedSlice final
       indices[sliceIndex] = index + nOutputElements;
     }
 
-    Value flatToStore = adaptor.getValueToStore();
+    Value flatToStore = asRankOne(rewriter, adaptor.getValueToStore());
     Value flatDest = adaptor.getDest();
     rewriter.replaceOpWithNewOp<vector::ShuffleOp>(insertStridedSliceOp,
                                                    flatDest.getType(), flatDest,
@@ -350,7 +657,7 @@ struct LinearizeVectorShuffle final : public OpConversionPattern<vector::ShuffleOp> {
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorShuffle(const TypeConverter &typeConverter,
-                         MLIRContext *context, PatternBenefit benefit = 1)
+                         MLIRContext *context, PatternBenefit benefit)
      : OpConversionPattern(typeConverter, context, benefit) {}
 
   LogicalResult
@@ -360,8 +667,8 @@ struct LinearizeVectorShuffle final
         getTypeConverter()->convertType<VectorType>(shuffleOp.getType());
     assert(dstType && "vector type destination expected.");
 
-    Value vec1 = adaptor.getV1();
-    Value vec2 = adaptor.getV2();
+    Value vec1 = asRankOne(rewriter, adaptor.getV1());
+    Value vec2 = asRankOne(rewriter, adaptor.getV2());
     int shuffleSliceLen = 1;
     int rank = shuffleOp.getV1().getType().getRank();
 
@@ -404,11 +711,11 @@ struct LinearizeVectorShuffle final
 /// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
 /// %out_nd = vector.shape_cast %out_1d
 /// `shuffle_indices_1d` is computed using the position of the original extract.
-struct LinearizeVectorExtract final
+struct VectorExtractToRankOneShuffle final
     : public OpConversionPattern<vector::ExtractOp> {
   using OpConversionPattern::OpConversionPattern;
-  LinearizeVectorExtract(const TypeConverter &typeConverter,
-                         MLIRContext *context, PatternBenefit benefit = 1)
+  VectorExtractToRankOneShuffle(const TypeConverter &typeConverter,
+                                MLIRContext *context, PatternBenefit benefit)
      : OpConversionPattern(typeConverter, context, benefit) {}
   LogicalResult matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
@@ -436,11 +743,120 @@ struct LinearizeVectorExtract final
       linearizedOffset += offsets[i] * size;
     }
 
+    Value v0 = asRankOne(rewriter, adaptor.getVector());
     llvm::SmallVector<int64_t> indices(size);
     std::iota(indices.begin(), indices.end(), linearizedOffset);
-    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
-        extractOp, dstTy, adaptor.getVector(), adaptor.getVector(), indices);
+    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(extractOp, dstTy, v0, v0,
+                                                   indices);
 
     return success();
   }
};
 
+/// Convert a vector.extract op with input rank > 1 to an operation with input
+/// of rank 1 and output of rank <= 1. Two lowering cases:
+///
+/// 1) If the result of the vector.extract is a scalar, convert it to a
+///    vector.extract on a rank-1 input which still outputs a scalar.
+///
+/// 2) Otherwise, convert to an extract_strided_slice op on a rank-1 vector.
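+///
+/// Example of case 2 (this mirrors the VectorExtractToRankOneStrided example
+/// in VectorLinearize.h):
+///
+/// BEFORE
+/// %1 = vector.extract %arg0[1, 2] : vector<2x1xi8> from vector<4x3x2x1xi8>
+///
+/// AFTER
+/// %0 = vector.shape_cast %arg0 : vector<4x3x2x1xi8> to vector<24xi8>
+/// %1 = vector.extract_strided_slice %0
+///          {offsets = [10], sizes = [2], strides = [1]}
+/// %2 = vector.shape_cast %1 : vector<2xi8> to vector<2x1xi8>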
+struct VectorExtractToRankOneStrided final
+    : public OpConversionPattern<vector::ExtractOp> {
+  using OpConversionPattern::OpConversionPattern;
+  VectorExtractToRankOneStrided(const TypeConverter &typeConverter,
+                                MLIRContext *context, PatternBenefit benefit)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    VectorType inType = extractOp.getVector().getType();
+    if (inType.getRank() == 1)
+      return failure();
+
+    SmallVector<OpFoldResult> offsets = extractOp.getMixedPosition();
+    auto maybeIntOffsets =
+        getIntegerOffsetsFromFoldResults(offsets, inType.getShape());
+    if (failed(maybeIntOffsets)) {
+      return failure();
+    }
+    const auto &intOffsets = maybeIntOffsets.value();
+    int64_t globalOffset = getIndexInFlattened(intOffsets, inType.getShape());
+
+    Location loc = extractOp.getLoc();
+
+    Type outType = extractOp.getType();
+
+    // Case 1 described above:
+    if (outType.isIntOrIndexOrFloat()) {
+      Value flattened = rewriter.create<vector::ExtractOp>(
+          loc, adaptor.getVector(), SmallVector<int64_t>{globalOffset});
+      rewriter.replaceOp(extractOp, flattened);
+      return success();
+    }
+
+    VectorType vOutType = dyn_cast<VectorType>(outType);
+    assert(vOutType && "expected vector type for output");
+
+    auto numberElementsOut = vOutType.getNumElements();
+    auto strided = rewriter.create<vector::ExtractStridedSliceOp>(
+        loc, adaptor.getVector(), SmallVector<int64_t>{globalOffset},
+        SmallVector<int64_t>{numberElementsOut}, SmallVector<int64_t>{1});
+
+    rewriter.replaceOp(extractOp, strided);
+    return success();
+  }
+};
+
+/// Convert a vector.insert op where the destination has rank > 1. Two cases:
+///
+/// 1) If the source to insert is a scalar, convert to a vector.insert op
+///    where the destination is rank-1.
+///
+/// 2) Otherwise, convert to a vector.insert_strided_slice op into a rank-1
+///    vector.
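+///
+/// Example of case 2 (an illustrative sketch; the AFTER form matches the
+/// STRIDED checks in linearize-insert-extract-preference.mlir):
+///
+/// BEFORE
+/// %0 = vector.insert %arg0, %arg1 [1] : vector<4xi8> into vector<3x4xi8>
+///
+/// AFTER
+/// [...shape casts to rank-1...]
+/// %1 = vector.insert_strided_slice %src_1d, %dst_1d
+///          {offsets = [4], strides = [1]} : vector<4xi8> into vector<12xi8>
+/// [...shape cast back...]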
+struct VectorInsertToRankOneStrided final
+    : public OpConversionPattern<vector::InsertOp> {
+  using OpConversionPattern::OpConversionPattern;
+  VectorInsertToRankOneStrided(const TypeConverter &typeConverter,
+                               MLIRContext *context, PatternBenefit benefit)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    VectorType largeType = insertOp.getDest().getType();
+    Type smallType = insertOp.getValueToStoreType();
+    SmallVector<OpFoldResult> positions = insertOp.getMixedPosition();
+    auto maybeIntOffsets =
+        getIntegerOffsetsFromFoldResults(positions, largeType.getShape());
+    if (failed(maybeIntOffsets)) {
+      return failure();
+    }
+    const auto &intOffsets = maybeIntOffsets.value();
+    int64_t globalOffset =
+        getIndexInFlattened(intOffsets, largeType.getShape());
+
+    Location loc = insertOp.getLoc();
+
+    // Case 1 described above:
+    if (smallType.isSignlessIntOrFloat()) {
+      auto flatOut = rewriter.create<vector::InsertOp>(
+          loc, adaptor.getValueToStore(), adaptor.getDest(),
+          SmallVector<int64_t>{globalOffset});
+      rewriter.replaceOp(insertOp, flatOut);
+      return success();
+    }
+
+    // Case 2 described above:
+    Value v0 = asRankOne(rewriter, adaptor.getValueToStore());
+    auto flatOut = rewriter.create<vector::InsertStridedSliceOp>(
+        loc, v0, adaptor.getDest(), SmallVector<int64_t>{globalOffset},
+        SmallVector<int64_t>{1});
+    rewriter.replaceOp(insertOp, flatOut);
+    return success();
+  }
+};
@@ -455,11 +871,11 @@ struct LinearizeVectorExtract final
 /// %out_1d = vector.shuffle %destination_1d, %source_1d [ shuffle_indices_1d
 /// ] %out_nd = vector.shape_cast %out_1d
 /// `shuffle_indices_1d` is computed using the position of the original insert.
-struct LinearizeVectorInsert final
+struct VectorInsertToRankOneShuffle final
     : public OpConversionPattern<vector::InsertOp> {
   using OpConversionPattern::OpConversionPattern;
-  LinearizeVectorInsert(const TypeConverter &typeConverter,
-                        MLIRContext *context, PatternBenefit benefit = 1)
+  VectorInsertToRankOneShuffle(const TypeConverter &typeConverter,
+                               MLIRContext *context, PatternBenefit benefit)
      : OpConversionPattern(typeConverter, context, benefit) {}
   LogicalResult matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
@@ -508,7 +924,8 @@ struct LinearizeVectorInsert final
     // [offset+srcNumElements, end)
 
     rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
-        insertOp, dstTy, adaptor.getDest(), adaptor.getValueToStore(), indices);
+        insertOp, dstTy, adaptor.getDest(),
+        asRankOne(rewriter, adaptor.getValueToStore()), indices);
 
     return success();
   }
@@ -526,7 +943,7 @@ struct LinearizeVectorBitCast final : public OpConversionPattern<vector::BitCastOp> {
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorBitCast(const TypeConverter &typeConverter,
-                         MLIRContext *context, PatternBenefit benefit = 1)
+                         MLIRContext *context, PatternBenefit benefit)
      : OpConversionPattern(typeConverter, context, benefit) {}
   LogicalResult matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
@@ -535,7 +952,7 @@ struct LinearizeVectorBitCast final
     assert(resType && "expected 1-D vector type");
     rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType,
                                                    adaptor.getSource());
-    return mlir::success();
+    return success();
   }
 };
 
@@ -550,7 +967,7 @@ struct LinearizeVectorSplat final
   using OpConversionPattern::OpConversionPattern;
 
   LinearizeVectorSplat(const TypeConverter &typeConverter, MLIRContext *context,
-                       PatternBenefit benefit = 1)
+                       PatternBenefit benefit)
      : OpConversionPattern(typeConverter, context, benefit) {}
 
   LogicalResult
@@ -581,7 +998,7 @@ struct LinearizeVectorCreateMask final
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorCreateMask(const TypeConverter &typeConverter,
-                            MLIRContext *context, PatternBenefit benefit = 1)
+                            MLIRContext *context, PatternBenefit benefit)
      : OpConversionPattern(typeConverter, context, benefit) {}
 
   LogicalResult
@@ -607,17 +1024,16 @@ struct LinearizeVectorCreateMask final
     // The result of the comparison is then multiplied with
     // the second operand of create_mask to get the 1D mask.
     auto firstOperand = adaptor.getOperands().front();
-    auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
-    auto isNonZero = rewriter.createOrFold<mlir::arith::CmpIOp>(
-        loc, mlir::arith::CmpIPredicate::sgt, firstOperand, zero);
-    auto isNonZeroIndex = rewriter.createOrFold<mlir::arith::IndexCastOp>(
+    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto isNonZero = rewriter.createOrFold<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::sgt, firstOperand, zero);
+    auto isNonZeroIndex = rewriter.createOrFold<arith::IndexCastOp>(
         loc, rewriter.getIndexType(), isNonZero);
     auto secondOperand = adaptor.getOperands().back();
-    auto maskSize = rewriter.createOrFold<mlir::arith::MulIOp>(
+    auto maskSize = rewriter.createOrFold<arith::MulIOp>(
         loc, rewriter.getIndexType(), isNonZeroIndex, secondOperand);
-    auto newMask =
-        rewriter.create<mlir::vector::CreateMaskOp>(loc, dstTy, maskSize);
+    auto newMask = rewriter.create<vector::CreateMaskOp>(loc, dstTy, maskSize);
     rewriter.replaceOp(createMaskOp, newMask);
     return success();
   }
@@ -625,104 +1041,179 @@ struct LinearizeVectorCreateMask final
 
 } // namespace
 
-/// This method defines the set of operations that are linearizable, and hence
-/// that are considered illegal for the conversion target.
-static bool isLinearizable(Operation *op) {
+/// This pattern converts a vector.extract_strided_slice into a new
+/// vector.extract_strided_slice where the operand and result of the new
+/// vector.extract_strided_slice have ranks that are as low as possible.
+///
+/// If the original vector.extract_strided_slice is a contiguous slice of
+/// a vector, then the new vector.extract_strided_slice will have rank-1
+/// operand and result. Otherwise additional dimensions will remain in the
+/// new operand and result.
+struct RankReduceExtractStridedSlice final
+    : public OpConversionPattern<vector::ExtractStridedSliceOp> {
+  using OpConversionPattern::OpConversionPattern;
 
-  // Only ops that are in the vector dialect, are ConstantLike, or
-  // are Vectorizable might be linearized currently.
-  StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace();
-  StringRef opDialect = op->getDialect()->getNamespace();
-  bool supported = (opDialect == vectorDialect) ||
-                   op->hasTrait<OpTrait::ConstantLike>() ||
-                   op->hasTrait<OpTrait::Vectorizable>();
-  if (!supported)
-    return false;
+  RankReduceExtractStridedSlice(const TypeConverter &typeConverter,
+                                MLIRContext *context, PatternBenefit benefit)
+      : OpConversionPattern(typeConverter, context, benefit) {}
 
-  return TypeSwitch<Operation *, bool>(op)
-      // As type legalization is done with vector.shape_cast, shape_cast
-      // itself cannot be linearized (will create new shape_casts to linearize
-      // ad infinitum).
-      .Case<vector::ShapeCastOp>([&](auto) { return false; })
-      // The operations
-      // - vector.extract_strided_slice
-      // - vector.extract
-      // - vector.insert_strided_slice
-      // - vector.insert
-      // are linearized to a rank-1 vector.shuffle by the current patterns.
-      // vector.shuffle only supports fixed size vectors, so it is impossible to
-      // use this approach to linearize these ops if they operate on scalable
-      // vectors.
-      .Case(
-          [&](vector::ExtractStridedSliceOp extractOp) {
-            return !extractOp.getType().isScalable();
-          })
-      .Case(
-          [&](vector::InsertStridedSliceOp insertOp) {
-            return !insertOp.getType().isScalable();
-          })
-      .Case([&](vector::InsertOp insertOp) {
-        return !insertOp.getType().isScalable();
-      })
-      .Case([&](vector::ExtractOp extractOp) {
-        return !extractOp.getSourceVectorType().isScalable();
-      })
-      .Default([&](auto) { return true; });
-}
+  LogicalResult
+  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
 
-void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
-                                              ConversionTarget &target) {
+    auto maybeCollapsed = getCollapsedExtractStridedSliceShape(extractOp);
+    if (!maybeCollapsed.has_value())
+      return failure();
 
-  auto convertType = [](Type type) -> std::optional<Type> {
-    VectorType vectorType = dyn_cast<VectorType>(type);
-    if (!vectorType || !isLinearizableVector(vectorType))
-      return type;
+    const auto &[collapsedOutShape, collapsedInShape, collapsedOffsets] =
+        maybeCollapsed.value();
 
-    VectorType linearizedType =
-        VectorType::get(vectorType.getNumElements(),
-                        vectorType.getElementType(), vectorType.isScalable());
-    return linearizedType;
-  };
-  typeConverter.addConversion(convertType);
+    VectorType collapsedInType =
+        VectorType::get(collapsedInShape, extractOp.getType().getElementType());
 
-  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
-                            Location loc) -> Value {
-    if (inputs.size() != 1)
-      return nullptr;
+    auto collapsedIn = rewriter.createOrFold<vector::ShapeCastOp>(
+        extractOp.getLoc(), collapsedInType, adaptor.getVector());
 
-    Value value = inputs.front();
-    if (!isa<VectorType>(type) || !isa<VectorType>(value.getType()))
-      return nullptr;
+    auto replacement = rewriter.create<vector::ExtractStridedSliceOp>(
+        extractOp.getLoc(), collapsedIn, collapsedOffsets, collapsedOutShape,
+        SmallVector<int64_t>(collapsedOffsets.size(), 1));
 
-    return builder.create<vector::ShapeCastOp>(loc, type, value);
-  };
-  typeConverter.addSourceMaterialization(materializeCast);
-  typeConverter.addTargetMaterialization(materializeCast);
+    VectorType flatOutputType =
+        getTypeConverter()->convertType<VectorType>(extractOp.getType());
 
-  target.markUnknownOpDynamicallyLegal(
-      [=](Operation *op) -> std::optional<bool> {
-        if (!isLinearizable(op))
-          return true;
-        // This will return true if, for all operand and result types `t`,
-        // convertType(t) = t. This is true if there are no rank>=2 vectors.
-        return typeConverter.isLegal(op);
-      });
-}
+    Value out = rewriter.createOrFold<vector::ShapeCastOp>(
+        extractOp.getLoc(), flatOutputType, replacement);
 
-void mlir::vector::populateVectorLinearizeBasePatterns(
-    const TypeConverter &typeConverter, const ConversionTarget &target,
-    RewritePatternSet &patterns) {
-  patterns
-      .add<LinearizeConstantLike, LinearizeVectorizable,
-           LinearizeVectorBitCast, LinearizeVectorSplat,
-           LinearizeVectorCreateMask>(
-          typeConverter, patterns.getContext());
-}
+    rewriter.replaceOp(extractOp, out);
+
+    return success();
+  }
+};
+
+struct RankReduceInsertStridedSlice final
+    : public OpConversionPattern<vector::InsertStridedSliceOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  RankReduceInsertStridedSlice(const TypeConverter &typeConverter,
+                               MLIRContext *context, PatternBenefit benefit)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto maybeCollapsed = getCollapsedInsertStridedSliceShape(insertOp);
+
+    if (!maybeCollapsed.has_value())
+      return failure();
+
+    const auto &[collapsedInShape, collapsedOutShape, collapsedOffsets] =
+        maybeCollapsed.value();
+
+    VectorType collapsedInType =
+        VectorType::get(collapsedInShape, insertOp.getType().getElementType());
+
+    Value collapsedIn = rewriter.createOrFold<vector::ShapeCastOp>(
+        insertOp.getLoc(), collapsedInType, adaptor.getValueToStore());
+
+    VectorType collapsedOutType =
+        VectorType::get(collapsedOutShape, insertOp.getType().getElementType());
+
+    Value collapsedDst = rewriter.createOrFold<vector::ShapeCastOp>(
+        insertOp.getLoc(), collapsedOutType, adaptor.getDest());
+
+    auto replacement = rewriter.create<vector::InsertStridedSliceOp>(
+        insertOp.getLoc(), collapsedIn, collapsedDst, collapsedOffsets,
+        SmallVector<int64_t>(collapsedOffsets.size(), 1));
+
+    Value out = rewriter.createOrFold<vector::ShapeCastOp>(
+        insertOp.getLoc(), insertOp.getType(), replacement);
+
+    rewriter.replaceOp(insertOp, out);
+
+    return success();
+  }
+};
 
-void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
-    const TypeConverter &typeConverter, const ConversionTarget &target,
-    RewritePatternSet &patterns) {
-  patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
-               LinearizeVectorInsert, LinearizeVectorExtractStridedSlice,
-               LinearizeVectorInsertStridedSlice>(typeConverter,
-                                                  patterns.getContext());
-}
+void vector::VectorLinearizePatterns::addToPatternSet(
+    const TypeConverter &typeConverter, RewritePatternSet &patterns) const {
+
+  MLIRContext *context = patterns.getContext();
+
+  if (isEnabled(LinearizePattern::LinearizeConstantLike))
+    patterns.add<LinearizeConstantLike>(
+        typeConverter, context,
+        getBenefit(LinearizePattern::LinearizeConstantLike));
+
+  if (isEnabled(LinearizePattern::LinearizeVectorizable))
+    patterns.add<LinearizeVectorizable>(
+        typeConverter, context,
+        getBenefit(LinearizePattern::LinearizeVectorizable));
+
+  if (isEnabled(LinearizePattern::LinearizeVectorBitCast))
+    patterns.add<LinearizeVectorBitCast>(
+        typeConverter, context,
+        getBenefit(LinearizePattern::LinearizeVectorBitCast));
+
+  if (isEnabled(LinearizePattern::LinearizeVectorCreateMask))
+    patterns.add<LinearizeVectorCreateMask>(
+        typeConverter, context,
+        getBenefit(LinearizePattern::LinearizeVectorCreateMask));
+
+  if (isEnabled(LinearizePattern::LinearizeVectorShuffle))
+    patterns.add<LinearizeVectorShuffle>(
+        typeConverter, context,
+        getBenefit(LinearizePattern::LinearizeVectorShuffle));
+
+  if (isEnabled(LinearizePattern::LinearizeVectorSplat))
+    patterns.add<LinearizeVectorSplat>(
+        typeConverter, context,
+        getBenefit(LinearizePattern::LinearizeVectorSplat));
+
+  // ------------------------ //
+  // Extract related patterns //
+  // ------------------------ //
+  if (isEnabled(LinearizePattern::VectorExtractToRankOneShuffle))
+    patterns.add<VectorExtractToRankOneShuffle>(
+        typeConverter, context,
+        getBenefit(LinearizePattern::VectorExtractToRankOneShuffle));
+
+  if (isEnabled(LinearizePattern::VectorExtractStridedSliceToRankOneShuffle))
+    patterns.add<VectorExtractStridedSliceToRankOneShuffle>(
+        typeConverter, context,
+        getBenefit(
+            LinearizePattern::VectorExtractStridedSliceToRankOneShuffle));
+
+  if (isEnabled(LinearizePattern::RankReduceExtractStridedSlice))
+    patterns.add<RankReduceExtractStridedSlice>(
+        typeConverter, context,
+        getBenefit(LinearizePattern::RankReduceExtractStridedSlice));
+
+  if (isEnabled(LinearizePattern::VectorExtractToRankOneStrided))
+    patterns.add<VectorExtractToRankOneStrided>(
+        typeConverter, context,
+        getBenefit(LinearizePattern::VectorExtractToRankOneStrided));
+
+  // ------------------------ //
+  // Insert related patterns  //
+  // ------------------------ //
+  if (isEnabled(LinearizePattern::VectorInsertToRankOneShuffle))
+    patterns.add<VectorInsertToRankOneShuffle>(
+        typeConverter, context,
+        getBenefit(LinearizePattern::VectorInsertToRankOneShuffle));
+
+  if (isEnabled(LinearizePattern::VectorInsertStridedSliceToRankOneShuffle))
+    patterns.add<VectorInsertStridedSliceToRankOneShuffle>(
+        typeConverter, context,
+        getBenefit(LinearizePattern::VectorInsertStridedSliceToRankOneShuffle));
+
+  if (isEnabled(LinearizePattern::RankReduceInsertStridedSlice))
+    patterns.add<RankReduceInsertStridedSlice>(
+        typeConverter, context,
+        getBenefit(LinearizePattern::RankReduceInsertStridedSlice));
+
+  if (isEnabled(LinearizePattern::VectorInsertToRankOneStrided))
+    patterns.add<VectorInsertToRankOneStrided>(
+        typeConverter, context,
+        getBenefit(LinearizePattern::VectorInsertToRankOneStrided));
+}
diff --git a/mlir/test/Dialect/Vector/linearize/linearize-insert-extract-preference.mlir b/mlir/test/Dialect/Vector/linearize/linearize-insert-extract-preference.mlir
new file mode 100644
index 0000000000000..a73602263d06e
--- /dev/null
+++ b/mlir/test/Dialect/Vector/linearize/linearize-insert-extract-preference.mlir
@@ -0,0 +1,287 @@
+// Everything becomes a shuffle (except rank-1 insert/extract).
+// RUN: mlir-opt %s -split-input-file -test-vector-linearize=preference=Shuffle | FileCheck %s --check-prefixes=SHUFFLE,ALL
+
+// RUN: mlir-opt %s -split-input-file -test-vector-linearize=preference=Strided | FileCheck %s --check-prefixes=STRIDED,ALL
+
+
+// **--------------------------------------------------------**
+// Tests of vector.insert
+// **--------------------------------------------------------**
+
+// vector.insert where the destination is a 1D vector is always unchanged.
+//
+// ALL-LABEL: insert_scalar_to_1D(
+// ALL-SAME:    %[[A0:.*]]: i8, %[[A1:.*]]: vector<4xi8>
+// ALL:         %[[IN0:.*]] = vector.insert %[[A0]], %[[A1]] [2] : i8 into vector<4xi8>
+// ALL:         return %[[IN0]] : vector<4xi8>
+func.func @insert_scalar_to_1D(%arg0 : i8, %arg1 : vector<4xi8>) -> vector<4xi8>
+{
+  %inserted = vector.insert %arg0, %arg1[2] : i8 into vector<4xi8>
+  return %inserted : vector<4xi8>
+}
+
+// -----
+
+// vector.insert of a scalar always becomes an insert of a scalar into a 1-D vector.
+//
+// ALL-LABEL: insert_scalar_to_2D(
+// ALL-SAME:    %[[A0:.*]]: i8, %[[A1:.*]]: vector<3x4xi8>
+// ALL:         %[[SC0:.*]] = vector.shape_cast %[[A1]] : vector<3x4xi8> to vector<12xi8>
+// ALL:         %[[IN0:.*]] = vector.insert %[[A0]], %[[SC0]] [9] : i8 into vector<12xi8>
+// ALL:         %[[SC1:.*]] = vector.shape_cast %[[IN0]] : vector<12xi8> to vector<3x4xi8>
+// ALL:         return %[[SC1]] : vector<3x4xi8>
+func.func @insert_scalar_to_2D(%arg0 : i8, %arg1 : vector<3x4xi8>) -> vector<3x4xi8>
+{
+  %inserted = vector.insert %arg0, %arg1[2, 1] : i8 into vector<3x4xi8>
+  return %inserted : vector<3x4xi8>
+}
+
+// -----
+
+// vector.insert where the source isn't a scalar. First case: 1D -> 2D.
diff --git a/mlir/test/Dialect/Vector/linearize/linearize-insert-extract-preference.mlir b/mlir/test/Dialect/Vector/linearize/linearize-insert-extract-preference.mlir
new file mode 100644
index 0000000000000..a73602263d06e
--- /dev/null
+++ b/mlir/test/Dialect/Vector/linearize/linearize-insert-extract-preference.mlir
@@ -0,0 +1,287 @@
+// Everything becomes a shuffle (except rank-1 insert/extract).
+// RUN: mlir-opt %s -split-input-file -test-vector-linearize=preference=Shuffle | FileCheck %s --check-prefixes=SHUFFLE,ALL
+
+// RUN: mlir-opt %s -split-input-file -test-vector-linearize=preference=Strided | FileCheck %s --check-prefixes=STRIDED,ALL
+
+
+// **--------------------------------------------------------**
+// Tests of vector.insert
+// **--------------------------------------------------------**
+
+// vector.insert where the destination is a 1D vector is always unchanged.
+//
+// ALL-LABEL: insert_scalar_to_1D(
+// ALL-SAME: %[[A0:.*]]: i8, %[[A1:.*]]: vector<4xi8>
+// ALL: %[[IN0:.*]] = vector.insert %[[A0]], %[[A1]] [2] : i8 into vector<4xi8>
+// ALL: return %[[IN0]] : vector<4xi8>
+func.func @insert_scalar_to_1D(%arg0 : i8, %arg1 : vector<4xi8>) -> vector<4xi8>
+{
+  %inserted = vector.insert %arg0, %arg1[2] : i8 into vector<4xi8>
+  return %inserted : vector<4xi8>
+}
+
+// -----
+
+// vector.insert of a scalar always becomes an insert of a scalar into a 1D vector.
+//
+// ALL-LABEL: insert_scalar_to_2D(
+// ALL-SAME: %[[A0:.*]]: i8, %[[A1:.*]]: vector<3x4xi8>
+// ALL: %[[SC0:.*]] = vector.shape_cast %[[A1]] : vector<3x4xi8> to vector<12xi8>
+// ALL: %[[IN0:.*]] = vector.insert %[[A0]], %[[SC0]] [9] : i8 into vector<12xi8>
+// ALL: %[[SC1:.*]] = vector.shape_cast %[[IN0]] : vector<12xi8> to vector<3x4xi8>
+// ALL: return %[[SC1]] : vector<3x4xi8>
+func.func @insert_scalar_to_2D(%arg0 : i8, %arg1 : vector<3x4xi8>) -> vector<3x4xi8>
+{
+  %inserted = vector.insert %arg0, %arg1[2, 1] : i8 into vector<3x4xi8>
+  return %inserted : vector<3x4xi8>
+}
+
+// -----
+
+// vector.insert where the source isn't a scalar. First case: 1D -> 2D.
+//
+// ALL-LABEL: insert_1D_to_2D(
+//
+// SHUFFLE: vector.shuffle {{.*}} [0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11]
+//
+// STRIDED: vector.insert_strided_slice {{.*}} {offsets = [4], strides = [1]}
+// STRIDED-SAME: vector<4xi8> into vector<12xi8>
+func.func @insert_1D_to_2D(%arg0 : vector<4xi8>, %arg1 : vector<3x4xi8>) -> vector<3x4xi8>
+{
+  %inserted = vector.insert %arg0, %arg1[1] : vector<4xi8> into vector<3x4xi8>
+  return %inserted : vector<3x4xi8>
+}
+
+// -----
+
+// vector.insert where the source isn't a scalar. Second case: 0D -> 2D.
+//
+// ALL-LABEL: insert_0D_to_2D(
+//
+// SHUFFLE: vector.shuffle {{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 10, 11] :
+// SHUFFLE-SAME: vector<12xi8>, vector<1xi8>
+//
+// STRIDED: vector.insert_strided_slice {{.*}} {offsets = [9], strides = [1]}
+// STRIDED-SAME: vector<1xi8> into vector<12xi8>
+func.func @insert_0D_to_2D(%arg0 : vector<i8>, %arg1 : vector<3x4xi8>) -> vector<3x4xi8>
+{
+  %inserted = vector.insert %arg0, %arg1[2, 1] : vector<i8> into vector<3x4xi8>
+  return %inserted : vector<3x4xi8>
+}
+
+// -----
+
+// vector.insert where the source isn't a scalar. Third case: 0D -> 1D.
+//
+// ALL-LABEL: insert_0D_to_1D(
+// ALL-SAME: %[[A0:.*]]: vector<i8>, %[[A1:.*]]: vector<4xi8>
+// ALL: %[[IN0:.*]] = vector.insert %[[A0]], %[[A1]] [2] : vector<i8> into vector<4xi8>
+// ALL: return %[[IN0]] : vector<4xi8>
+func.func @insert_0D_to_1D(%arg0 : vector<i8>, %arg1 : vector<4xi8>) -> vector<4xi8>
+{
+  %inserted = vector.insert %arg0, %arg1[2] : vector<i8> into vector<4xi8>
+  return %inserted : vector<4xi8>
+}
+
+// -----
+
+// vector.insert where the source isn't a scalar. Fourth case: 2D -> 4D.
+//
+// ALL-LABEL: insert_2D_to_4D(
+// ALL-COUNT-2: shape_cast
+//
+// SHUFFLE: vector.shuffle {{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16, 17, 18, 19] :
+// SHUFFLE-SAME: vector<16xi8>, vector<4xi8>
+//
+// STRIDED: vector.insert_strided_slice {{.*}} {offsets = [12], strides = [1]}
+// STRIDED-SAME: vector<4xi8> into vector<16xi8>
+func.func @insert_2D_to_4D(%arg0 : vector<2x2xi8>, %arg1 : vector<2x2x2x2xi8>) -> vector<2x2x2x2xi8>
+{
+  %inserted = vector.insert %arg0, %arg1[1, 1] : vector<2x2xi8> into vector<2x2x2x2xi8>
+  return %inserted : vector<2x2x2x2xi8>
+}
+
+// -----
+
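The SHUFFLE masks checked above all follow one rule: the mask is the identity over the destination, except over the span being written, where it indexes into the source operand (whose elements are numbered after the destination's, since vector.shuffle concatenates its two operands). A standalone sketch of the computation, assuming unit strides and a contiguous destination span; insertShuffleMask is an illustrative name, not a function in the patch:

    #include <cstdint>
    #include <vector>

    std::vector<int64_t> insertShuffleMask(int64_t dstNumElements,
                                           int64_t srcNumElements,
                                           int64_t insertOffset) {
      std::vector<int64_t> mask(dstNumElements);
      for (int64_t i = 0; i < dstNumElements; ++i)
        mask[i] = i; // identity: keep the destination element
      for (int64_t i = 0; i < srcNumElements; ++i)
        mask[insertOffset + i] = dstNumElements + i; // read from the source
      return mask;
    }

    // insertShuffleMask(12, 4, 4) yields
    // [0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11],
    // matching the SHUFFLE check in @insert_1D_to_2D above.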
+// **--------------------------------------------------------**
+// Tests of vector.extract
+// **--------------------------------------------------------**
+
+// vector.extract where the source is a 1D vector is always unchanged.
+//
+// ALL-LABEL: extract_scalar_from_1D(
+// ALL-SAME: %[[A0:.*]]: vector<4xi8>
+// ALL: %[[EX0:.*]] = vector.extract %[[A0]][2] : i8 from vector<4xi8>
+// ALL: return %[[EX0]] : i8
+func.func @extract_scalar_from_1D(%arg0 : vector<4xi8>) -> i8
+{
+  %extracted = vector.extract %arg0[2] : i8 from vector<4xi8>
+  return %extracted : i8
+}
+
+// -----
+
+// ALL-LABEL: extract_scalar_from_2D(
+// ALL-SAME: %[[A0:.*]]: vector<3x4xi8>
+// ALL: %[[SC0:.*]] = vector.shape_cast %[[A0]] : vector<3x4xi8> to vector<12xi8>
+// ALL: %[[EX0:.*]] = vector.extract %[[SC0]][9] : i8 from vector<12xi8>
+// ALL: return %[[EX0]] : i8
+func.func @extract_scalar_from_2D(%arg0 : vector<3x4xi8>) -> i8
+{
+  %extracted = vector.extract %arg0[2, 1] : i8 from vector<3x4xi8>
+  return %extracted : i8
+}
+
+// -----
+
+// ALL-LABEL: extract_1D_from_2D(
+//
+// SHUFFLE: vector.shuffle
+// SHUFFLE-SAME: [4, 5, 6, 7] : vector<12xi8>, vector<12xi8>
+//
+// STRIDED: vector.extract_strided_slice
+// STRIDED-SAME: {offsets = [4], sizes = [4], strides = [1]} : vector<12xi8> to vector<4xi8>
+func.func @extract_1D_from_2D(%arg0 : vector<3x4xi8>) -> vector<4xi8>
+{
+  %extracted = vector.extract %arg0[1] : vector<4xi8> from vector<3x4xi8>
+  return %extracted : vector<4xi8>
+}
+
+// -----
+
+// ALL-LABEL: extract_2D_from_4D(
+//
+// SHUFFLE: vector.shuffle
+// SHUFFLE-SAME: [10, 11] : vector<24xi8>, vector<24xi8>
+//
+// STRIDED: vector.extract_strided_slice
+// STRIDED-SAME: {offsets = [10], sizes = [2], strides = [1]} : vector<24xi8> to vector<2xi8>
func.func @extract_2D_from_4D(%arg0 : vector<4x3x2x1xi8>) -> vector<2x1xi8> {
+  %extracted = vector.extract %arg0[1, 2] : vector<2x1xi8> from vector<4x3x2x1xi8>
+  return %extracted : vector<2x1xi8>
+}
+
+// **--------------------------------------------------------**
+// Tests of vector.insert_strided_slice
+// **--------------------------------------------------------**
+
+// -----
+
+// ALL-LABEL: insert_strided_slice_1D(
+//
+// SHUFFLE: shuffle {{.*}} [0, 8, 9, 3, 4, 5, 6, 7]
+//
+// STRIDED: insert_strided_slice {{.*}} {offsets = [1], strides = [1]}
+func.func @insert_strided_slice_1D(%arg0 : vector<2xi8>, %arg1 : vector<8xi8>) -> vector<8xi8>
+{
+  %inserted = vector.insert_strided_slice %arg0, %arg1 {offsets = [1], strides = [1]} : vector<2xi8> into vector<8xi8>
+  return %inserted : vector<8xi8>
+}
+
+// -----
+
+// ALL-LABEL: insert_strided_slice_4D_contiguous(
+//
+// SHUFFLE: vector.shuffle
+// SHUFFLE-SAME: 52, 53, 120, 121
+// SHUFFLE-SAME: 130, 131, 66, 67
+// SHUFFLE-SAME: vector<120xi8>, vector<12xi8>
+//
+// STRIDED: vector.insert_strided_slice
+// STRIDED-SAME: {offsets = [54], strides = [1]}
+// STRIDED-SAME: vector<12xi8> into vector<120xi8>
+func.func @insert_strided_slice_4D_contiguous(%arg0 : vector<1x2x2x3xi8>, %arg1 : vector<5x4x2x3xi8>) -> vector<5x4x2x3xi8> {
+  %inserted = vector.insert_strided_slice %arg0, %arg1 {offsets = [2, 1, 0, 0], strides = [1, 1, 1, 1]} : vector<1x2x2x3xi8> into vector<5x4x2x3xi8>
+  return %inserted : vector<5x4x2x3xi8>
+}
+
+// -----
+
+// This insert_strided_slice is not contiguous, and so it is always linearized
+// to a 1D vector.shuffle.
+
+// ALL-LABEL: insert_strided_slice_4D_noncontiguous(
+// ALL: vector.shuffle
+// ALL-SAME: [0, 1, 2, 8, 4, 5, 6, 9] : vector<8xi8>, vector<2xi8>
+func.func @insert_strided_slice_4D_noncontiguous(%arg0 : vector<1x2x1x1xi8>, %arg1 : vector<1x2x2x2xi8>) -> vector<1x2x2x2xi8> {
+  %inserted = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 1, 1], strides = [1, 1, 1, 1]} : vector<1x2x1x1xi8> into vector<1x2x2x2xi8>
+  return %inserted : vector<1x2x2x2xi8>
+}
+
+// -----
+
+// **--------------------------------------------------------**
+// Tests of vector.extract_strided_slice
+// **--------------------------------------------------------**
+
+// ALL-LABEL: extract_strided_slice_1D(
+//
+// SHUFFLE: vector.shuffle {{.*}} [1, 2]
+//
+// STRIDED: vector.extract_strided_slice
+// STRIDED-SAME: {offsets = [1], sizes = [2], strides = [1]}
+// STRIDED-SAME: vector<8xi8> to vector<2xi8>
+func.func @extract_strided_slice_1D(%arg0 : vector<8xi8>) -> vector<2xi8>
+{
+  %extracted = vector.extract_strided_slice %arg0 {offsets = [1], sizes = [2], strides = [1]} : vector<8xi8> to vector<2xi8>
+  return %extracted : vector<2xi8>
+}
+
+// -----
+
+// ALL-LABEL: extract_strided_slice_4D_contiguous_1(
+//
+// SHUFFLE: vector.shuffle
+// SHUFFLE-SAME: [3, 4, 5]
+// SHUFFLE-SAME: vector<6xi8>, vector<6xi8>
+//
+// STRIDED: vector.extract_strided_slice
+// STRIDED-SAME: {offsets = [3], sizes = [3], strides = [1]}
+// STRIDED-SAME: vector<6xi8> to vector<3xi8>
+func.func @extract_strided_slice_4D_contiguous_1(%arg0 : vector<2x1x3x1xi8>) -> vector<1x1x3x1xi8> {
+  %extracted = vector.extract_strided_slice %arg0 {offsets = [1, 0, 0, 0], sizes = [1, 1, 3, 1], strides = [1, 1, 1, 1]} : vector<2x1x3x1xi8> to vector<1x1x3x1xi8>
+  return %extracted : vector<1x1x3x1xi8>
+}
+
+// -----
+
+// ALL-LABEL: extract_strided_slice_4D_contiguous_2(
+//
+// SHUFFLE: vector.shuffle
+// SHUFFLE-SAME: [3, 4]
+// SHUFFLE-SAME: vector<6xi8>, vector<6xi8>
+//
+// STRIDED: vector.extract_strided_slice
+// STRIDED-SAME: {offsets = [3], sizes = [2], strides = [1]}
+// STRIDED-SAME: vector<6xi8> to vector<2xi8>
+func.func @extract_strided_slice_4D_contiguous_2(%arg0 : vector<2x1x3x1xi8>) -> vector<1x1x2x1xi8> {
+  %extracted = vector.extract_strided_slice %arg0 {offsets = [1, 0, 0, 0], sizes = [1, 1, 2, 1], strides = [1, 1, 1, 1]} : vector<2x1x3x1xi8> to vector<1x1x2x1xi8>
+  return %extracted : vector<1x1x2x1xi8>
+}
+
+// -----
+
+// ALL-LABEL: extract_strided_slice_4D_noncontiguous(
+// ALL: vector.shuffle
+// ALL-SAME: [0, 1, 3, 4]
+// ALL-SAME: vector<6xi8>, vector<6xi8>
+func.func @extract_strided_slice_4D_noncontiguous(%arg0 : vector<2x1x3x1xi8>) -> vector<2x1x2x1xi8> {
+  %extracted = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0], sizes = [2, 1, 2, 1], strides = [1, 1, 1, 1]} : vector<2x1x3x1xi8> to vector<2x1x2x1xi8>
+  return %extracted : vector<2x1x2x1xi8>
+}
diff --git a/mlir/test/Dialect/Vector/linearize-subject-to-bitwidth.mlir b/mlir/test/Dialect/Vector/linearize/linearize-subject-to-bitwidth.mlir
similarity index 100%
rename from mlir/test/Dialect/Vector/linearize-subject-to-bitwidth.mlir
rename to mlir/test/Dialect/Vector/linearize/linearize-subject-to-bitwidth.mlir
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize/linearize.mlir
similarity index 99%
rename from mlir/test/Dialect/Vector/linearize.mlir
rename to mlir/test/Dialect/Vector/linearize/linearize.mlir
index 9cbf319ffddb2..b7a5448c7dc22 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize/linearize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -test-vector-linearize -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -test-vector-linearize=preference=Shuffle -verify-diagnostics | FileCheck %s
 
 // CHECK-LABEL: test_linearize
 // CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
@@ -297,7 +297,6 @@ func.func @test_vector_extract_scalable(%arg0: vector<2x8x[2]xf32>) -> vector<8x
 // CHECK-LABEL: test_vector_insert
 // CHECK-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<8x4xf32>) -> vector<2x8x4xf32> {
 func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>) -> vector<2x8x4xf32> {
-  // CHECK-DAG: %[[ARG_SRC:.*]] = vector.shape_cast %[[SRC]] : vector<8x4xf32> to vector<32xf32>
   // CHECK-DAG: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<64xf32>
 
   // CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG_DEST]], %[[ARG_SRC]]
diff --git a/mlir/test/Dialect/Vector/linearize/rank-reduce-strided-ops.mlir b/mlir/test/Dialect/Vector/linearize/rank-reduce-strided-ops.mlir
new file mode 100644
index 0000000000000..342936b0f7eed
--- /dev/null
+++ b/mlir/test/Dialect/Vector/linearize/rank-reduce-strided-ops.mlir
@@ -0,0 +1,135 @@
+// RUN: mlir-opt %s -split-input-file -test-rank-reduce-strided-slice-ops -verify-diagnostics | FileCheck %s
+
+
+// **---------------------------------------------**
+// Tests of vector.extract_strided_slice
+// **---------------------------------------------**
+
+// The 6 elements extracted are contiguous, so this can be expressed as a
+// rank-1 vector.extract_strided_slice.
+
+// CHECK-LABEL: @extract_strided_slice_2D_to_1D(
+// CHECK-SAME: %[[A:.*]]: vector<5x2xi8>) -> vector<3x2xi8> {
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[A]] : vector<5x2xi8> to vector<10xi8>
+// CHECK: %[[EXTRACTED:.*]] = vector.extract_strided_slice %[[SC]]
+// CHECK-SAME: {offsets = [2], sizes = [6], strides = [1]} : vector<10xi8> to vector<6xi8>
+// CHECK: %[[CASTED:.*]] = vector.shape_cast %[[EXTRACTED]] : vector<6xi8> to vector<3x2xi8>
+// CHECK: return %[[CASTED]] : vector<3x2xi8>
+func.func @extract_strided_slice_2D_to_1D(%arg0 : vector<5x2xi8>) -> vector<3x2xi8> {
+  %extracted = vector.extract_strided_slice %arg0 {offsets = [1, 0], sizes = [3, 2], strides = [1, 1]} : vector<5x2xi8> to vector<3x2xi8>
+  return %extracted : vector<3x2xi8>
+}
+
+// -----
+
+// The 5 elements extracted are not contiguous, so this cannot be expressed as
+// a rank-1 vector.extract_strided_slice.
+
+// CHECK-LABEL: @negative_extract_strided_slice_2D_to_1D(
+// CHECK-SAME: %[[A:.*]]: vector<5x2xi8>) -> vector<5x1xi8> {
+// CHECK: %[[EXTRACTED:.*]] = vector.extract_strided_slice %[[A]]
+// CHECK: return %[[EXTRACTED]] : vector<5x1xi8>
+func.func @negative_extract_strided_slice_2D_to_1D(%arg0 : vector<5x2xi8>) -> vector<5x1xi8> {
+  %extracted = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [5, 1], strides = [1, 1]} : vector<5x2xi8> to vector<5x1xi8>
+  return %extracted : vector<5x1xi8>
+}
+
+// -----
+
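For the higher-rank cases that follow, the collapsed offsets and sizes are plain mixed-radix arithmetic over each group of merged dimensions. A sketch, assuming a grouping of adjacent dimensions has already been chosen and that, within each group, every dimension after the first is taken in full; collapseGroup is an illustrative name, not the patch's helper:

    #include <cstdint>
    #include <vector>

    struct Dim {
      int64_t shape, offset, size;
    };

    // Merge adjacent dimensions into one: shapes and sizes multiply, offsets
    // linearize. Valid only if every dim after the group's first is unsliced.
    Dim collapseGroup(const std::vector<Dim> &group) {
      Dim out{1, 0, 1};
      for (const Dim &d : group) {
        out.shape *= d.shape;
        out.offset = out.offset * d.shape + d.offset;
        out.size *= d.size;
      }
      return out;
    }

    // For @extract_strided_slice_4D_becomes_2D below (vector<8x7x6x5xi8>,
    // offsets [2, 0, 1, 0], sizes [2, 7, 2, 5]), grouping dims {0,1}, {2,3}:
    //   collapseGroup({{8, 2, 2}, {7, 0, 7}}) == {56, 14, 14}
    //   collapseGroup({{6, 1, 2}, {5, 0, 5}}) == {30,  5, 10}
    // reproducing offsets = [14, 5] and sizes = [14, 10] in the CHECK lines.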
+// The 2 elements extracted are contiguous, so this can be expressed as a
+// rank-1 vector.extract_strided_slice.
+
+// CHECK-LABEL: @extract_strided_slice_4D_leading_ones(
+// CHECK-SAME: %[[A:.*]]: vector<2x1x3x1xi8>) -> vector<1x1x2x1xi8> {
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[A]] : vector<2x1x3x1xi8> to vector<6xi8>
+// CHECK: %[[EXTRACTED:.*]] = vector.extract_strided_slice %[[SC]]
+// CHECK-SAME: {offsets = [3], sizes = [2], strides = [1]} : vector<6xi8> to vector<2xi8>
+// CHECK: %[[CASTED:.*]] = vector.shape_cast %[[EXTRACTED]] : vector<2xi8> to vector<1x1x2x1xi8>
+// CHECK: return %[[CASTED]] : vector<1x1x2x1xi8>
+func.func @extract_strided_slice_4D_leading_ones(%arg0 : vector<2x1x3x1xi8>) -> vector<1x1x2x1xi8> {
+  %extracted = vector.extract_strided_slice %arg0 {offsets = [1, 0, 0, 0], sizes = [1, 1, 2, 1], strides = [1, 1, 1, 1]} : vector<2x1x3x1xi8> to vector<1x1x2x1xi8>
+  return %extracted : vector<1x1x2x1xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @extract_strided_slice_4D_becomes_2D(
+// CHECK-SAME: %[[A:.*]]: vector<8x7x6x5xi8>) -> vector<2x7x2x5xi8> {
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[A]] : vector<8x7x6x5xi8> to vector<56x30xi8>
+// CHECK: %[[EXTRACTED:.*]] = vector.extract_strided_slice %[[SC]]
+// CHECK-SAME: {offsets = [14, 5], sizes = [14, 10], strides = [1, 1]} : vector<56x30xi8> to vector<14x10xi8>
+// CHECK: %[[CASTED:.*]] = vector.shape_cast %[[EXTRACTED]] : vector<14x10xi8> to vector<2x7x2x5xi8>
+// CHECK: return %[[CASTED]] : vector<2x7x2x5xi8>
func.func @extract_strided_slice_4D_becomes_2D(%arg0 : vector<8x7x6x5xi8>) -> vector<2x7x2x5xi8> {
+  %extracted = vector.extract_strided_slice %arg0 {offsets = [2, 0, 1, 0], sizes = [2, 7, 2, 5], strides = [1, 1, 1, 1]} : vector<8x7x6x5xi8> to vector<2x7x2x5xi8>
+  return %extracted : vector<2x7x2x5xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @extract_strided_slice_4D_becomes_3D(
+// CHECK-SAME: %[[A:.*]]: vector<8x7x6x5xi8>) -> vector<8x2x6x2xi8> {
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[A]] : vector<8x7x6x5xi8> to vector<8x42x5xi8>
+// CHECK: %[[EXTRACTED:.*]] = vector.extract_strided_slice %[[SC]]
+// CHECK-SAME: {offsets = [0, 12, 1], sizes = [8, 12, 2], strides = [1, 1, 1]} : vector<8x42x5xi8> to vector<8x12x2xi8>
+// CHECK: %[[CASTED:.*]] = vector.shape_cast %[[EXTRACTED]] : vector<8x12x2xi8> to vector<8x2x6x2xi8>
+// CHECK: return %[[CASTED]] : vector<8x2x6x2xi8>
func.func @extract_strided_slice_4D_becomes_3D(%arg0 : vector<8x7x6x5xi8>) -> vector<8x2x6x2xi8> {
+  %extracted = vector.extract_strided_slice %arg0 {offsets = [0, 2, 0, 1], sizes = [8, 2, 6, 2], strides = [1, 1, 1, 1]} : vector<8x7x6x5xi8> to vector<8x2x6x2xi8>
+  return %extracted : vector<8x2x6x2xi8>
+}
+
+// -----
+
+// **---------------------------------------------**
+// Tests of vector.insert_strided_slice
+// **---------------------------------------------**
+
+
+// CHECK-LABEL: @negative_insert_strided_slice(
+// CHECK-SAME: %[[A:.*]]: vector<2x2xi8>, %[[B:.*]]: vector<2x1xi8>) -> vector<2x2xi8> {
+// CHECK: %[[INSERTED:.*]] = vector.insert_strided_slice %[[B]], %[[A]]
+// CHECK: return %[[INSERTED]] : vector<2x2xi8>
func.func @negative_insert_strided_slice(%arg0 : vector<2x2xi8>, %arg1 : vector<2x1xi8>) -> vector<2x2xi8> {
+  %inserted = vector.insert_strided_slice %arg1, %arg0 {offsets = [0, 1], strides = [1, 1]} : vector<2x1xi8> into vector<2x2xi8>
+  return %inserted : vector<2x2xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @positive_insert_strided_slice(
+// CHECK-SAME: %[[A:.*]]: vector<2x2xi8>, %[[B:.*]]: vector<1x2xi8>) -> vector<2x2xi8> {
+// CHECK-DAG: %[[SCA:.*]] = vector.shape_cast %[[A]] : vector<2x2xi8> to vector<4xi8>
+// CHECK-DAG: %[[SCB:.*]] = vector.shape_cast %[[B]] : vector<1x2xi8> to vector<2xi8>
+// CHECK: %[[INSERTED:.*]] = vector.insert_strided_slice %[[SCB]], %[[SCA]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xi8> into vector<4xi8>
+// CHECK: %[[CASTED:.*]] = vector.shape_cast %[[INSERTED]] : vector<4xi8> to vector<2x2xi8>
+// CHECK: return %[[CASTED]] : vector<2x2xi8>
+func.func @positive_insert_strided_slice(%arg0 : vector<2x2xi8>, %arg1 : vector<1x2xi8>) -> vector<2x2xi8> {
+  %inserted = vector.insert_strided_slice %arg1, %arg0 {offsets = [0, 0], strides = [1, 1]} : vector<1x2xi8> into vector<2x2xi8>
+  return %inserted : vector<2x2xi8>
+}
+
+// -----
+
func.func @test_extract_strided_slice_4D(%arg0 : vector<2x2x2x2xi8>) -> vector<1x2x1x2xi8> {
+  %0 = vector.extract_strided_slice %arg0
+      {offsets = [1, 0, 1, 0],
+       sizes = [1, 2, 1, 2],
+       strides = [1, 1, 1, 1]} : vector<2x2x2x2xi8> to vector<1x2x1x2xi8>
+  return %0 : vector<1x2x1x2xi8>
+}
+
+// -----
+
+// Equivalent to the above, but now with an insert_strided_slice.
+
func.func @test_insert_strided_slice_4D(%arg0 : vector<2x2x2x2xi8>, %arg1 : vector<1x2x1x2xi8>) -> vector<2x2x2x2xi8> {
+  %0 = vector.insert_strided_slice %arg1, %arg0
+      {offsets = [1, 0, 1, 0],
+       strides = [1, 1, 1, 1]} : vector<1x2x1x2xi8> into vector<2x2x2x2xi8>
+  return %0 : vector<2x2x2x2xi8>
+}
diff --git a/mlir/test/lib/Dialect/Vector/CMakeLists.txt b/mlir/test/lib/Dialect/Vector/CMakeLists.txt
index e16937029ac0e..1ce069599af43 100644
--- a/mlir/test/lib/Dialect/Vector/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Vector/CMakeLists.txt
@@ -1,6 +1,7 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRVectorTestPasses
   TestVectorTransforms.cpp
+  TestVectorLinearize.cpp
 
   EXCLUDE_FROM_LIBMLIR
 )
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorLinearize.cpp b/mlir/test/lib/Dialect/Vector/TestVectorLinearize.cpp
new file mode 100644
index 0000000000000..21d6356bf04fa
--- /dev/null
+++ b/mlir/test/lib/Dialect/Vector/TestVectorLinearize.cpp
@@ -0,0 +1,248 @@
+//===- TestVectorLinearize.cpp - Test Vector linearization ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include <algorithm>
+#include <limits>
+#include <optional>
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
+#include "mlir/Dialect/Vector/Transforms/VectorLinearize.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LLVM.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+struct TestVectorLinearize final
+    : public PassWrapper<TestVectorLinearize, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
+
+  TestVectorLinearize() = default;
+
+  TestVectorLinearize(const TestVectorLinearize &pass) : PassWrapper(pass) {}
+
+  StringRef getArgument() const override { return "test-vector-linearize"; }
+  StringRef getDescription() const override {
+    return "Linearizes ND vectors for N >= 2 into 1D vectors";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<vector::VectorDialect>();
+  }
+  Option<InsertExtractLinearizePreference> preference{
+      *this, "preference",
+      llvm::cl::desc("Corresponds 1:1 with InsertExtractLinearizePreference"),
+      llvm::cl::values(
+          clEnumValN(
+              static_cast<int>(InsertExtractLinearizePreference::Strided),
+              "Strided", ""),
+          clEnumValN(
+              static_cast<int>(InsertExtractLinearizePreference::Shuffle),
+              "Shuffle", ""))};
+
+  void runOnOperation() override {
+    MLIRContext &context = getContext();
+    TypeConverter converter;
+    RewritePatternSet patterns(&context);
+    ConversionTarget target(context);
+    initializeForVectorLinearize(converter);
+    populateForFullVectorLinearize(converter, target, patterns, preference);
+
+    mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
+        converter, patterns, target);
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+struct TestRankReduceStridedSliceOps final
+    : public PassWrapper<TestRankReduceStridedSliceOps, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRankReduceStridedSliceOps)
+
+  TestRankReduceStridedSliceOps() = default;
+
+  StringRef getArgument() const override {
+    return "test-rank-reduce-strided-slice-ops";
+  }
+  StringRef getDescription() const override {
+    return "Test pass for rank-reducing strided slice ops.";
+  }
+
+  void runOnOperation() override {
+    MLIRContext &context = getContext();
+    TypeConverter typeConverter;
+    RewritePatternSet patterns(&context);
+    ConversionTarget target(context);
+
+    VectorLinearizePatterns()
+        .enableAll(false)
+        .enable(LinearizePattern::RankReduceInsertStridedSlice)
+        .enable(LinearizePattern::RankReduceExtractStridedSlice)
+        .addToPatternSet(typeConverter, patterns);
+
+    typeConverter.addConversion(
+        [](Type t) -> std::optional<Type> { return t; });
+
+    typeConverter.addSourceMaterialization(
+        [](OpBuilder &builder, Type type, ValueRange inputs,
+           Location loc) -> Value { return inputs.front(); });
+
+    target.markUnknownOpDynamicallyLegal(
+        [&](Operation *op) -> std::optional<bool> {
+          if (auto insertOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
+            return !getCollapsedInsertStridedSliceShape(insertOp).has_value();
+          }
+          if (auto extractOp = dyn_cast<vector::ExtractStridedSliceOp>(op)) {
+            return !getCollapsedExtractStridedSliceShape(extractOp).has_value();
+          }
+          return true;
+        });
+
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+struct TestVectorBitWidthLinearize final
+    : public PassWrapper<TestVectorBitWidthLinearize, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBitWidthLinearize)
+
+  TestVectorBitWidthLinearize() = default;
+  TestVectorBitWidthLinearize(const TestVectorBitWidthLinearize &pass)
+      : PassWrapper(pass) {}
+
+  StringRef getArgument() const override {
+    return "test-bit-width-constrained-vector-linearize";
+  }
+  StringRef getDescription() const override {
+    return "Linearizes ND vectors for N >= 2 into 1D vectors, with constraints "
+           "on inner-most dimension's bit width. If the inner-most dimension "
+           "exceeds a threshold, the op is not linearized.";
+  }
+  Option<unsigned> targetVectorBitwidth{
+      *this, "target-vector-bitwidth",
+      llvm::cl::desc(
+          "Minimum vector bitwidth to enable the flattening transformation"),
+      llvm::cl::init(std::numeric_limits<unsigned>::max())};
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext &context = getContext();
+    TypeConverter typeConverter;
+    RewritePatternSet patterns(&context);
+    ConversionTarget target(context);
+    populateWithBitWidthConstraints(typeConverter, target, patterns,
+                                    targetVectorBitwidth);
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      return signalPassFailure();
+  }
+
+private:
+  /// If `type` is a VectorType with trailing dimension of (bit) size greater
+  /// than or equal to `targetBitWidth`, its defining op is considered legal.
+  static bool
+  isNotLinearizableBecauseLargeInnerDimension(Type type,
+                                              unsigned targetBitWidth) {
+
+    VectorType vecType = dyn_cast<VectorType>(type);
+
+    // Not linearizable for reasons other than what this function checks.
+    if (!vecType || vecType.getRank() == 0)
+      return false;
+
+    // The width of the type 'index' is unbounded (and therefore potentially
+    // above the target width).
+    if (vecType.getElementType().isIndex())
+      return true;
+
+    unsigned finalDimSize = vecType.getShape().back();
+    unsigned nbBitsPerElm = vecType.getElementTypeBitWidth();
+    unsigned trailingVecDimBitWidth = finalDimSize * nbBitsPerElm;
+    return trailingVecDimBitWidth >= targetBitWidth;
+  }
+
+  static bool
+  isNotLinearizableBecauseLargeInnerDimension(Operation *op,
+                                              unsigned targetBitWidth) {
+    // Check on bitwidths.
+    SmallVector<std::pair<Type, unsigned>> toCheck =
+        getTypeBitWidthBoundPairs(op, targetBitWidth);
+    return std::any_of(toCheck.begin(), toCheck.end(),
+                       [&](std::pair<Type, unsigned> typeWidth) {
+                         return isNotLinearizableBecauseLargeInnerDimension(
+                             typeWidth.first, typeWidth.second);
+                       });
+  }
+
+  static void populateWithBitWidthConstraints(TypeConverter &typeConverter,
+                                              ConversionTarget &target,
+                                              RewritePatternSet &patterns,
+                                              unsigned targetBitWidth) {
+
+    initializeForVectorLinearize(typeConverter);
+    populateForFullVectorLinearize(typeConverter, target, patterns,
+                                   InsertExtractLinearizePreference::Shuffle);
+
+    // Extend the set of legal ops to include those with large inner-most
+    // dimensions on selected operands/results.
+    target.markUnknownOpDynamicallyLegal(
+        [=](Operation *op) -> std::optional<bool> {
+          if (isNotLinearizableBecauseLargeInnerDimension(op, targetBitWidth)) {
+            return true;
+          }
+          return {};
+        });
+  }
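To make the threshold concrete, here is the trailing-dimension arithmetic the legality callback above performs, for a few types at -target-vector-bitwidth=128. The numbers follow directly from the code; the table itself is illustrative:

    // trailing bits = trailing dim size * element bit width
    //   vector<4x8xf16> : 8 * 16 = 128 >= 128  -> legal, not linearized
    //   vector<2x4xf32> : 4 * 32 = 128 >= 128  -> legal, not linearized
    //   vector<4x4xf16> : 4 * 16 =  64 <  128  -> illegal, linearized
    //   vector<2xindex> : unbounded width      -> legal, not linearized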
+  /// Get the set of operand/result types to check for sufficiently
+  /// small inner-most dimension size.
+  static SmallVector<std::pair<Type, unsigned>>
+  getTypeBitWidthBoundPairs(Operation *op, unsigned targetBitWidth) {
+
+    if (auto insertOp = dyn_cast<vector::InsertOp>(op)) {
+      unsigned w = targetBitWidth < std::numeric_limits<unsigned>::max()
+                       ? targetBitWidth + 1
+                       : targetBitWidth;
+      return {{insertOp.getValueToStoreType(), w}};
+    }
+
+    auto resultTypes = op->getResultTypes();
+    SmallVector<std::pair<Type, unsigned>> resultsWithBitWidth;
+    resultsWithBitWidth.reserve(resultTypes.size());
+    for (Type type : resultTypes) {
+      resultsWithBitWidth.push_back({type, targetBitWidth});
+    }
+    return resultsWithBitWidth;
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+extern void registerTestVectorLinearize() {
+  PassRegistration<TestVectorLinearize>();
+  PassRegistration<TestVectorBitWidthLinearize>();
+  PassRegistration<TestRankReduceStridedSliceOps>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index f4f32e9339870..5c75d32c22236 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -17,7 +17,6 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
@@ -837,160 +836,6 @@ struct TestVectorEmulateMaskedLoadStore final
   }
 };
 
-/// Get the set of operand/result types to check for sufficiently
-/// small inner-most dimension size.
-static SmallVector<std::pair<Type, unsigned>>
-getTypeBitWidthBoundPairs(Operation *op, unsigned targetBitWidth) {
-
-  if (auto insertOp = dyn_cast<vector::InsertOp>(op)) {
-    unsigned w = targetBitWidth < std::numeric_limits<unsigned>::max()
-                     ? targetBitWidth + 1
-                     : targetBitWidth;
-    return {{insertOp.getValueToStoreType(), w}};
-  }
-
-  auto resultTypes = op->getResultTypes();
-  SmallVector<std::pair<Type, unsigned>> resultsWithBitWidth;
-  resultsWithBitWidth.reserve(resultTypes.size());
-  for (Type type : resultTypes) {
-    resultsWithBitWidth.push_back({type, targetBitWidth});
-  }
-  return resultsWithBitWidth;
-}
-
-/// If `type` is VectorType with trailing dimension of (bit) size greater than
-/// or equal to `targetBitWidth`, its defining op is considered legal.
-static bool
-isNotLinearizableBecauseLargeInnerDimension(Type type,
-                                            unsigned targetBitWidth) {
-
-  VectorType vecType = dyn_cast<VectorType>(type);
-
-  // Not linearizable for reasons other than what this function checks.
-  if (!vecType || vecType.getRank() == 0)
-    return false;
-
-  // The width of the type 'index' is unbounded (and therefore potentially
-  // above the target width).
-  if (vecType.getElementType().isIndex())
-    return true;
-
-  unsigned finalDimSize = vecType.getShape().back();
-  unsigned nbBitsPerElm = vecType.getElementTypeBitWidth();
-  unsigned trailingVecDimBitWidth = finalDimSize * nbBitsPerElm;
-  return trailingVecDimBitWidth >= targetBitWidth;
-}
-
-static bool
-isNotLinearizableBecauseLargeInnerDimension(Operation *op,
-                                            unsigned targetBitWidth) {
-  // Check on bitwidths.
-  SmallVector<std::pair<Type, unsigned>> toCheck =
-      getTypeBitWidthBoundPairs(op, targetBitWidth);
-  return llvm::any_of(toCheck, [&](std::pair<Type, unsigned> typeWidth) {
-    return isNotLinearizableBecauseLargeInnerDimension(typeWidth.first,
-                                                       typeWidth.second);
-  });
-}
-
-void populateWithBitWidthConstraints(TypeConverter &typeConverter,
-                                     ConversionTarget &target,
-                                     unsigned targetBitWidth) {
-
-  // The general purpose definition of what ops are legal must come first.
-  populateForVectorLinearize(typeConverter, target);
-
-  // Extend the set of legal ops to include those with large inner-most
-  // dimensions on selected operands/results.
-  target.markUnknownOpDynamicallyLegal(
-      [=](Operation *op) -> std::optional<bool> {
-        if (isNotLinearizableBecauseLargeInnerDimension(op, targetBitWidth)) {
-          return true;
-        }
-        return {};
-      });
-}
-
-struct TestVectorBitWidthLinearize final
-    : public PassWrapper<TestVectorBitWidthLinearize, OperationPass<>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBitWidthLinearize)
-
-  TestVectorBitWidthLinearize() = default;
-  TestVectorBitWidthLinearize(const TestVectorBitWidthLinearize &pass)
-      : PassWrapper(pass) {}
-
-  StringRef getArgument() const override {
-    return "test-bit-width-constrained-vector-linearize";
-  }
-  StringRef getDescription() const override {
-    return "Linearizes ND vectors for N >= 2 into 1D vectors, with constraints "
-           "in inner-most dimension's bit width.";
-  }
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<vector::VectorDialect>();
-  }
-
-  Option<unsigned> targetVectorBitwidth{
-      *this, "target-vector-bitwidth",
-      llvm::cl::desc(
-          "Minimum vector bitwidth to enable the flattening transformation"),
-      llvm::cl::init(std::numeric_limits<unsigned>::max())};
-  void runOnOperation() override {
-    auto *context = &getContext();
-
-    TypeConverter typeConverter;
-    RewritePatternSet patterns(context);
-    ConversionTarget target(*context);
-
-    populateWithBitWidthConstraints(typeConverter, target,
-                                    targetVectorBitwidth);
-
-    vector::populateVectorLinearizeBasePatterns(typeConverter, target,
-                                                patterns);
-
-    vector::populateVectorLinearizeShuffleLikeOpsPatterns(typeConverter, target,
-                                                          patterns);
-
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
-      return signalPassFailure();
-  }
-};
-
-struct TestVectorLinearize final
-    : public PassWrapper<TestVectorLinearize, OperationPass<>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
-
-  TestVectorLinearize() = default;
-
-  StringRef getArgument() const override { return "test-vector-linearize"; }
-  StringRef getDescription() const override {
-    return "Linearizes ND vectors for N >= 2 into 1D vectors";
-  }
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<vector::VectorDialect>();
-  }
-
-  void runOnOperation() override {
-    MLIRContext &context = getContext();
-    TypeConverter converter;
-    RewritePatternSet patterns(&context);
-    ConversionTarget target(context);
-
-    vector::populateForVectorLinearize(converter, target);
-
-    vector::populateVectorLinearizeBasePatterns(converter, target, patterns);
-    vector::populateVectorLinearizeShuffleLikeOpsPatterns(converter, target,
-                                                          patterns);
-    mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
-        converter, patterns, target);
-
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
-      return signalPassFailure();
-  }
-};
-
 struct TestEliminateVectorMasks
     : public PassWrapper<TestEliminateVectorMasks,
                          OperationPass<func::FuncOp>> {
@@ -1062,10 +907,6 @@ void registerTestVectorLowerings() {
   PassRegistration<TestVectorEmulateMaskedLoadStore>();
 
-  PassRegistration<TestVectorLinearize>();
-
-  PassRegistration<TestVectorBitWidthLinearize>();
-
   PassRegistration<TestEliminateVectorMasks>();
 }
 } // namespace test
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 2e08ae6f37980..f52f36107e301 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -155,6 +155,7 @@ void registerTestTopologicalSortAnalysisPass();
 void registerTestTransformDialectEraseSchedulePass();
 void registerTestPassStateExtensionCommunication();
 void registerTestVectorLowerings();
+void registerTestVectorLinearize();
 void registerTestVectorReductionToSPIRVDotProd();
 void registerTestVulkanRunnerPipeline();
 void registerTestWrittenToPass();
@@ -300,6 +301,7 @@ void registerTestPasses() {
   mlir::test::registerTestTransformDialectEraseSchedulePass();
   mlir::test::registerTestPassStateExtensionCommunication();
   mlir::test::registerTestVectorLowerings();
+  mlir::test::registerTestVectorLinearize();
   mlir::test::registerTestVectorReductionToSPIRVDotProd();
   mlir::test::registerTestVulkanRunnerPipeline();
   mlir::test::registerTestWrittenToPass();