Merged
121 changes: 121 additions & 0 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -33,6 +33,7 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Support/LLVM.h"
@@ -2387,9 +2388,129 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
return success();
}

/// Rewrite from_elements on multiple scalar extracts as a shape_cast
/// on a single extract. Example:
/// %0 = vector.extract %source[0, 0] : i8 from vector<2x2xi8>
/// %1 = vector.extract %source[0, 1] : i8 from vector<2x2xi8>
/// %2 = vector.from_elements %0, %1 : vector<2xi8>
///
/// becomes
/// %1 = vector.extract %source[0] : vector<2xi8> from vector<2x2xi8>
/// %2 = vector.shape_cast %1 : vector<2xi8> to vector<2xi8>
/// (in this example the shape_cast is the identity and later folds away).
///
/// The requirements for this to be valid are
///
/// i) The elements are extracted from the same vector (%source).
///
/// ii) The elements form a suffix of %source. Specifically, the number
/// of elements is the same as the product of the last N dimension sizes
/// of %source, for some N.
///
/// iii) The elements are extracted contiguously in ascending order.
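///
/// For instance (an illustrative non-matching example), the following is
/// left unchanged because two elements is not the product of any trailing
/// dimension sizes of vector<2x3xi8> (3 or 6), so condition (ii) fails:
/// %0 = vector.extract %source[0, 0] : i8 from vector<2x3xi8>
/// %1 = vector.extract %source[0, 1] : i8 from vector<2x3xi8>
/// %2 = vector.from_elements %0, %1 : vector<2xi8>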

class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {

using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(FromElementsOp fromElements,
PatternRewriter &rewriter) const override {

// Handled by `rewriteFromElementsAsSplat`
if (fromElements.getType().getNumElements() == 1)
return failure();

// The common source that all elements are extracted from, if one exists.
TypedValue<VectorType> source;
// The position of the combined extract operation, if one is created.
ArrayRef<int64_t> combinedPosition;
// The expected index of extraction of the current element in the loop, if
// elements are extracted contiguously in ascending order.
SmallVector<int64_t> expectedPosition;

for (auto [insertIndex, element] :
llvm::enumerate(fromElements.getElements())) {

// Check that the element is from a vector.extract operation.
auto extractOp =
dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp());
if (!extractOp) {
return rewriter.notifyMatchFailure(fromElements,
"element not from vector.extract");
}

// Check condition (i) by checking that all elements have the same source
// as the first element.
if (insertIndex == 0) {
source = extractOp.getVector();
} else if (extractOp.getVector() != source) {
return rewriter.notifyMatchFailure(fromElements,
"element from different vector");
}

ArrayRef<int64_t> position = extractOp.getStaticPosition();
int64_t rank = position.size();
assert(rank == source.getType().getRank() &&
"scalar extract must have full rank position");

// Check condition (ii) by checking that the position that the first
// element is extracted from has sufficient trailing 0s. For example, in
//
// %elm0 = vector.extract %source[1, 0, 0] : i8 from vector<2x3x4xi8>
// [...]
// %elms = vector.from_elements %elm0, [...] : vector<12xi8>
//
// The 2 trailing 0s in the position of extraction of %elm0 cover 3*4 = 12
// elements, which is the number of elements of %elms, so this is valid.
if (insertIndex == 0) {
const int64_t numElms = fromElements.getType().getNumElements();
int64_t numSuffixElms = 1;
int64_t index = rank;
while (index > 0 && position[index - 1] == 0 &&
numSuffixElms < numElms) {
numSuffixElms *= source.getType().getDimSize(index - 1);
--index;
}
if (numSuffixElms != numElms) {
return rewriter.notifyMatchFailure(
fromElements, "elements do not form a suffix of source");
}
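// Continuing the example above: the loop stops with index == 1 and
// numSuffixElms == 12, so combinedPosition becomes [1]; the combined
// extract created below is vector<3x4xi8> from vector<2x3x4xi8>, which
// the final shape_cast flattens to vector<12xi8>.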
expectedPosition = llvm::to_vector(position);
combinedPosition = position.drop_back(rank - index);
}

// Check condition (iii).
else if (expectedPosition != position) {
return rewriter.notifyMatchFailure(
fromElements, "elements not in ascending order (static order)");
}
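// Advance the position that the next element must be extracted from. In
// the vector<2x2xi8> example above, this steps [0, 0] -> [0, 1] -> [1, 0]
// (the last value is unused once all elements have been visited).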
increment(expectedPosition, source.getType().getShape());
}

auto extracted = rewriter.createOrFold<vector::ExtractOp>(
fromElements.getLoc(), source, combinedPosition);

rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
fromElements, fromElements.getType(), extracted);

return success();
}

/// Increments n-D `indices` by 1 starting from the innermost dimension.
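/// For example, with shape [2, 3], indices [0, 2] increment to [1, 0], and
/// [1, 2] wraps around to [0, 0].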
static void increment(MutableArrayRef<int64_t> indices,
ArrayRef<int64_t> shape) {
for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
indices[dim] += 1;
if (indices[dim] < shape[dim])
break;
indices[dim] = 0;
}
}
};

void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add(rewriteFromElementsAsSplat);
results.add<FromElementsToShapeCast>(context);
}

//===----------------------------------------------------------------------===//
69 changes: 0 additions & 69 deletions mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2943,75 +2943,6 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>,

// -----

// CHECK-LABEL: func @extract_scalar_from_from_elements(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) {
// Extract from 0D.
%0 = vector.from_elements %a : vector<f32>
%1 = vector.extract %0[] : f32 from vector<f32>

// Extract from 1D.
%2 = vector.from_elements %a : vector<1xf32>
%3 = vector.extract %2[0] : f32 from vector<1xf32>
%4 = vector.from_elements %a, %b, %a, %a, %b : vector<5xf32>
%5 = vector.extract %4[4] : f32 from vector<5xf32>

// Extract from 2D.
%6 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
%7 = vector.extract %6[0, 0] : f32 from vector<2x3xf32>
%8 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
%9 = vector.extract %6[1, 1] : f32 from vector<2x3xf32>
%10 = vector.extract %6[1, 2] : f32 from vector<2x3xf32>

// CHECK: return %[[a]], %[[a]], %[[b]], %[[a]], %[[a]], %[[b]], %[[b]]
return %1, %3, %5, %7, %8, %9, %10 : f32, f32, f32, f32, f32, f32, f32
}

// -----

// CHECK-LABEL: func @extract_1d_from_from_elements(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) {
%0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
// CHECK: %[[splat1:.*]] = vector.splat %[[a]] : vector<3xf32>
%1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32>
// CHECK: %[[splat2:.*]] = vector.splat %[[b]] : vector<3xf32>
%2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32>
// CHECK: return %[[splat1]], %[[splat2]]
return %1, %2 : vector<3xf32>, vector<3xf32>
}

// -----

// CHECK-LABEL: func @extract_2d_from_from_elements(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>) {
%0 = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32>
// CHECK: %[[splat1:.*]] = vector.from_elements %[[a]], %[[a]], %[[a]], %[[b]] : vector<2x2xf32>
%1 = vector.extract %0[0] : vector<2x2xf32> from vector<3x2x2xf32>
// CHECK: %[[splat2:.*]] = vector.from_elements %[[b]], %[[b]], %[[b]], %[[a]] : vector<2x2xf32>
%2 = vector.extract %0[1] : vector<2x2xf32> from vector<3x2x2xf32>
// CHECK: return %[[splat1]], %[[splat2]]
return %1, %2 : vector<2x2xf32>, vector<2x2xf32>
}

// -----

// CHECK-LABEL: func @from_elements_to_splat(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector<f32>) {
// CHECK: %[[splat:.*]] = vector.splat %[[a]] : vector<2x3xf32>
%0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32>
// CHECK: %[[from_el:.*]] = vector.from_elements {{.*}} : vector<2x3xf32>
%1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32>
// CHECK: %[[splat2:.*]] = vector.splat %[[a]] : vector<f32>
%2 = vector.from_elements %a : vector<f32>
// CHECK: return %[[splat]], %[[from_el]], %[[splat2]]
return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
}

// -----

// CHECK-LABEL: func @vector_insert_const_regression(
// CHECK: llvm.mlir.undef
// CHECK: vector.insert