From 207b0e11365b61db78896741d17d2505f8c3f891 Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 20 May 2025 15:25:12 -0700 Subject: [PATCH 1/2] first commit in preparation --- .../mlir/Dialect/Vector/IR/VectorOps.td | 4 + mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 114 ++++++++++++++++ .../Vector/Transforms/VectorLinearize.cpp | 126 ++---------------- 3 files changed, 132 insertions(+), 112 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 3f5564541554e..f8412863b18c9 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -1114,6 +1114,8 @@ def Vector_InsertStridedSliceOp : return ::llvm::cast(attr).getInt() != 1; }); } + // \return The indices in dest that the values are inserted to. + FailureOr> getLinearIndices(); }]; let hasFolder = 1; @@ -1254,6 +1256,8 @@ def Vector_ExtractStridedSliceOp : return ::llvm::cast(attr).getInt() != 1; }); } + // \return The indices in source that the values are taken from. + FailureOr> getLinearIndices(); }]; let hasCanonicalizer = 1; let hasFolder = 1; diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 41777347975da..e800b7b7c9ff6 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -3182,6 +3182,101 @@ void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result, stridesAttr); } +/// Convert an array of attributes into a vector of integers, if possible. +static FailureOr> intsFromArrayAttr(ArrayAttr attrs) { + if (!attrs) + return failure(); + SmallVector ints; + ints.reserve(attrs.size()); + for (auto attr : attrs) { + if (auto intAttr = dyn_cast(attr)) { + ints.push_back(intAttr.getInt()); + } else { + return failure(); + } + } + return ints; +} + +/// Consider inserting a vector of shape `small` into a vector of shape `large`, +/// at position `offsets`: this function enumeratates all the indices in `large` +/// that are written to. The enumeration is with row-major ordering. +/// +/// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2 +/// positions written to are (1,3) and (1,4), which have linearized indices 8 +/// and 9. So [8,9] is returned. +/// +/// The length of the returned vector is equal to the number of elements in +/// the shape `small` (i.e. the product of dimensions of `small`). +static SmallVector +getStridedSliceInsertionIndices(ArrayRef small, + ArrayRef large, + ArrayRef offsets) { + + // Example of alignment between, `large`, `small` and `offsets`: + // large = 4, 5, 6, 7, 8 + // small = 1, 6, 7, 8 + // offsets = 2, 3, 0 + // + // `offsets` has implicit trailing 0s, `small` has implicit leading 1s. + assert((large.size() >= small.size()) && + "rank of 'large' cannot be lower than rank of 'small'"); + assert((large.size() >= offsets.size()) && + "rank of 'large' cannot be lower than the number of offsets"); + unsigned delta = large.size() - small.size(); + unsigned nOffsets = offsets.size(); + auto getSmall = [&](int64_t i) -> int64_t { + return i >= delta ? small[i - delta] : 1; + }; + auto getOffset = [&](int64_t i) -> int64_t { + return i < nOffsets ? offsets[i] : 0; + }; + + // Using 2 vectors of indices, at each iteration populate the updated set of + // indices based on the old set of indices, and the size of the small vector + // in the current iteration. + SmallVector indices{0}; + int64_t stride = 1; + for (int i = large.size() - 1; i >= 0; --i) { + int64_t currentSize = indices.size(); + int64_t smallSize = getSmall(i); + int64_t nextSize = currentSize * smallSize; + SmallVector nextIndices(nextSize); + int64_t *base = nextIndices.begin(); + int64_t offset = getOffset(i) * stride; + for (int j = 0; j < smallSize; ++j) { + for (int k = 0; k < currentSize; ++k) { + base[k] = indices[k] + offset; + } + offset += stride; + base += currentSize; + } + stride *= large[i]; + indices = std::move(nextIndices); + } + return indices; +} + +FailureOr> InsertStridedSliceOp::getLinearIndices() { + + // Stride > 1 to be considered if/when the insert_strided_slice supports it. + if (hasNonUnitStrides()) + return failure(); + + // Only when the destination has a static size can the indices be enumerated. + if (getType().isScalable()) + return failure(); + + // Only when the offsets are all static can the indices be enumerated. + FailureOr> offsets = intsFromArrayAttr(getOffsets()); + if (failed(offsets)) + return failure(); + + return getStridedSliceInsertionIndices(getSourceVectorType().getShape(), + getDestVectorType().getShape(), + offsets.value()); +} + // TODO: Should be moved to Tablegen ConfinedAttr attributes. template static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, @@ -3638,6 +3733,25 @@ void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result, stridesAttr); } +FailureOr> ExtractStridedSliceOp::getLinearIndices() { + + // Stride > 1 to be considered if/when extract_strided_slice supports it. + if (hasNonUnitStrides()) + return failure(); + + // Only when the source has a static size can the indices be enumerated. + if (getSourceVectorType().isScalable()) + return failure(); + + // Only when the offsets are all static can the indices be enumerated. + FailureOr> offsets = intsFromArrayAttr(getOffsets()); + if (failed(offsets)) + return failure(); + + return getStridedSliceInsertionIndices( + getType().getShape(), getSourceVectorType().getShape(), offsets.value()); +} + LogicalResult ExtractStridedSliceOp::verify() { auto type = getSourceVectorType(); auto offsets = getOffsetsAttr(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 678a88627ca82..6cf818bbd0695 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -109,90 +109,6 @@ struct LinearizeVectorizable final } }; -template -static bool stridesAllOne(TOp op) { - static_assert( - std::is_same_v || - std::is_same_v, - "expected vector.extract_strided_slice or vector.insert_strided_slice"); - ArrayAttr strides = op.getStrides(); - return llvm::all_of(strides, isOneInteger); -} - -/// Convert an array of attributes into a vector of integers, if possible. -static FailureOr> intsFromArrayAttr(ArrayAttr attrs) { - if (!attrs) - return failure(); - SmallVector ints; - ints.reserve(attrs.size()); - for (auto attr : attrs) { - if (auto intAttr = dyn_cast(attr)) { - ints.push_back(intAttr.getInt()); - } else { - return failure(); - } - } - return ints; -} - -/// Consider inserting a vector of shape `small` into a vector of shape `large`, -/// at position `offsets`: this function enumeratates all the indices in `large` -/// that are written to. The enumeration is with row-major ordering. -/// -/// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2 -/// positions written to are (1,3) and (1,4), which have linearized indices 8 -/// and 9. So [8,9] is returned. -/// -/// The length of the returned vector is equal to the number of elements in -/// the shape `small` (i.e. the product of dimensions of `small`). -SmallVector static getStridedSliceInsertionIndices( - ArrayRef small, ArrayRef large, - ArrayRef offsets) { - - // Example of alignment between, `large`, `small` and `offsets`: - // large = 4, 5, 6, 7, 8 - // small = 1, 6, 7, 8 - // offsets = 2, 3, 0 - // - // `offsets` has implicit trailing 0s, `small` has implicit leading 1s. - assert((large.size() >= small.size()) && - "rank of 'large' cannot be lower than rank of 'small'"); - assert((large.size() >= offsets.size()) && - "rank of 'large' cannot be lower than the number of offsets"); - unsigned delta = large.size() - small.size(); - unsigned nOffsets = offsets.size(); - auto getSmall = [&](int64_t i) -> int64_t { - return i >= delta ? small[i - delta] : 1; - }; - auto getOffset = [&](int64_t i) -> int64_t { - return i < nOffsets ? offsets[i] : 0; - }; - - // Using 2 vectors of indices, at each iteration populate the updated set of - // indices based on the old set of indices, and the size of the small vector - // in the current iteration. - SmallVector indices{0}; - int64_t stride = 1; - for (int i = large.size() - 1; i >= 0; --i) { - int64_t currentSize = indices.size(); - int64_t smallSize = getSmall(i); - int64_t nextSize = currentSize * smallSize; - SmallVector nextIndices(nextSize); - int64_t *base = nextIndices.begin(); - int64_t offset = getOffset(i) * stride; - for (int j = 0; j < smallSize; ++j) { - for (int k = 0; k < currentSize; ++k) { - base[k] = indices[k] + offset; - } - offset += stride; - base += currentSize; - } - stride *= large[i]; - indices = std::move(nextIndices); - } - return indices; -} - /// This pattern converts a vector.extract_strided_slice operation into a /// vector.shuffle operation that has a rank-1 (linearized) operand and result. /// @@ -231,30 +147,23 @@ struct LinearizeVectorExtractStridedSlice final // Expect a legalization failure if the strides are not all 1 (if ever the // verifier for extract_strided_slice allows non-1 strides). - if (!stridesAllOne(extractStridedSliceOp)) { + if (extractStridedSliceOp.hasNonUnitStrides()) { return rewriter.notifyMatchFailure( extractStridedSliceOp, "extract_strided_slice with strides != 1 not supported"); } - FailureOr> offsets = - intsFromArrayAttr(extractStridedSliceOp.getOffsets()); - if (failed(offsets)) { + FailureOr> indices = + extractStridedSliceOp.getLinearIndices(); + if (failed(indices)) { return rewriter.notifyMatchFailure(extractStridedSliceOp, - "failed to get integer offsets"); + "failed to get indices"); } - ArrayRef inputShape = - extractStridedSliceOp.getSourceVectorType().getShape(); - - ArrayRef outputShape = extractStridedSliceOp.getType().getShape(); - - SmallVector indices = getStridedSliceInsertionIndices( - outputShape, inputShape, offsets.value()); - Value srcVector = adaptor.getVector(); - rewriter.replaceOpWithNewOp( - extractStridedSliceOp, flatOutputType, srcVector, srcVector, indices); + rewriter.replaceOpWithNewOp(extractStridedSliceOp, + flatOutputType, srcVector, + srcVector, indices.value()); return success(); } }; @@ -298,31 +207,24 @@ struct LinearizeVectorInsertStridedSlice final // Expect a legalization failure if the strides are not all 1 (if ever the // verifier for insert_strided_slice allows non-1 strides). - if (!stridesAllOne(insertStridedSliceOp)) { + if (insertStridedSliceOp.hasNonUnitStrides()) { return rewriter.notifyMatchFailure( insertStridedSliceOp, "insert_strided_slice with strides != 1 not supported"); } - VectorType inputType = insertStridedSliceOp.getValueToStore().getType(); - ArrayRef inputShape = inputType.getShape(); - VectorType outputType = insertStridedSliceOp.getType(); - ArrayRef outputShape = outputType.getShape(); int64_t nOutputElements = outputType.getNumElements(); - FailureOr> offsets = - intsFromArrayAttr(insertStridedSliceOp.getOffsets()); - if (failed(offsets)) { + FailureOr> sliceIndices = + insertStridedSliceOp.getLinearIndices(); + if (failed(sliceIndices)) return rewriter.notifyMatchFailure(insertStridedSliceOp, - "failed to get integer offsets"); - } - SmallVector sliceIndices = getStridedSliceInsertionIndices( - inputShape, outputShape, offsets.value()); + "failed to get indices"); SmallVector indices(nOutputElements); std::iota(indices.begin(), indices.end(), 0); - for (auto [index, sliceIndex] : llvm::enumerate(sliceIndices)) { + for (auto [index, sliceIndex] : llvm::enumerate(sliceIndices.value())) { indices[sliceIndex] = index + nOutputElements; } From 1cc6345d5cde6972b3bd1d1708f0c4f152af349b Mon Sep 17 00:00:00 2001 From: James Newling Date: Thu, 22 May 2025 12:06:23 -0700 Subject: [PATCH 2/2] comment improvement --- mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index f8412863b18c9..481523ff10c3f 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -1114,7 +1114,7 @@ def Vector_InsertStridedSliceOp : return ::llvm::cast(attr).getInt() != 1; }); } - // \return The indices in dest that the values are inserted to. + // \return The indices in `dest` where values are stored. FailureOr> getLinearIndices(); }]; @@ -1256,7 +1256,7 @@ def Vector_ExtractStridedSliceOp : return ::llvm::cast(attr).getInt() != 1; }); } - // \return The indices in source that the values are taken from. + // \return The indices in `source` where values are extracted. FailureOr> getLinearIndices(); }]; let hasCanonicalizer = 1;