Skip to content
Merged
97 changes: 97 additions & 0 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Support/LLVM.h"
Expand Down Expand Up @@ -2385,9 +2386,105 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
return success();
}

/// Rewrite vector.from_elements as vector.shape_cast, if possible.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[ultra nit] "if possible" is implicit and "less is more" :)

Suggested change
/// Rewrite vector.from_elements as vector.shape_cast, if possible.
/// Rewrite vector.from_elements as vector.shape_cast.

///
/// Example:
/// %0 = vector.extract %source[0, 0] : i8 from vector<1x2xi8>
/// %1 = vector.extract %source[0, 1] : i8 from vector<1x2xi8>
/// %2 = vector.from_elements %0, %1 : vector<2xi8>
///
/// becomes
/// %2 = vector.shape_cast %source : vector<1x2xi8> to vector<2xi8>
///
/// The requirements for this to be valid are
/// i) source and from_elements result have the same number of elements,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] Otherwise it's not clear what source and from_elements are. Perhaps there's better way to clarify 🤔

Suggested change
/// i) source and from_elements result have the same number of elements,
/// i) vector.extract and vector.from_elements result have the same number of elements,

/// ii) all elements are extracted from the same vector (%source),
/// iii) the elements are extracted in ascending order.
///
/// It might be possible to rewrite vector.from_elements as a single
/// vector.extract if (i) is not satisifed, or in some cases as a
/// a single vector_extract_strided_slice if (i) and (iii) are not satisfied,
/// this is left for future consideration.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know we already have quite a few TODOs/FIXMEs that basically mean "let’s look at this later." But the phrasing "It might be possible…” feels particularly vague here - I’d suggest omitting it unless we can be more specific.

If we do want to leave a note, maybe something like:

“Consider extending to use a single vector.extract when (i) does not hold.”

Also, just a general thought: extending this pattern could quickly become quite complex. If we're seeing bad code that would benefit from such a complicated rewrite, it might be worth checking whether the producer of that code could be improved instead.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense. I'll spend a bit of time trying to canonicalize directly to vector.extract, I don't think it'll be significantly more complex.

class FromElementsToShapCast : public OpRewritePattern<FromElementsOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(FromElementsOp fromElements,
PatternRewriter &rewriter) const override {

// The source of the first element. This is initialized in the first
// iteration of the loop over elements.
TypedValue<VectorType> firstElementSource;

for (auto [insertIndex, element] :
llvm::enumerate(fromElements.getElements())) {

// Check that the element is from a vector.extract operation.
auto extractOp =
dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp());
if (!extractOp) {
return rewriter.notifyMatchFailure(fromElements,
"element not from vector.extract");
}

// Check condition (i) on the first element. As we will check that all
// elements have the same source, we don't need to check condition (i) for
// any other elements.
if (insertIndex == 0) {
firstElementSource = extractOp.getVector();
if (static_cast<size_t>(
firstElementSource.getType().getNumElements()) !=
fromElements.getType().getNumElements()) {
return rewriter.notifyMatchFailure(fromElements,
"number of elements differ");
}
}

// Check condition (ii), by checking that all elements have same source as
// the first element.
Value currentSource = extractOp.getVector();
if (currentSource != firstElementSource) {
return rewriter.notifyMatchFailure(fromElements,
"element from different vector");
}

// Check condition (iii).
// First, get the index that the element is extracted from.
int64_t extractIndex{0};
int64_t stride{1};
ArrayRef<int64_t> position = extractOp.getStaticPosition();
assert(position.size() ==
static_cast<size_t>(firstElementSource.getType().getRank()) &&
"scalar extract must have full rank position");
for (auto [pos, size] :
llvm::zip(llvm::reverse(position),
llvm::reverse(firstElementSource.getType().getShape()))) {
if (pos == ShapedType::kDynamic) {
return rewriter.notifyMatchFailure(
fromElements, "elements not in ascending order (dynamic order)");
}
extractIndex += pos * stride;
stride *= size;
}

// Second, check that the index of extraction from source and insertion in
// from_elements are the same.
if (extractIndex != static_cast<int64_t>(insertIndex)) {
return rewriter.notifyMatchFailure(
fromElements, "elements not in ascending order (static order)");
}
}

rewriter.replaceOpWithNewOp<ShapeCastOp>(
fromElements, fromElements.getType(), firstElementSource);
return success();
}
};

void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add(rewriteFromElementsAsSplat);
results.add<FromElementsToShapCast>(context);
}

//===----------------------------------------------------------------------===//
Expand Down
69 changes: 0 additions & 69 deletions mlir/test/Dialect/Vector/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2952,75 +2952,6 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>,

// -----

// CHECK-LABEL: func @extract_scalar_from_from_elements(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) {
// Extract from 0D.
%0 = vector.from_elements %a : vector<f32>
%1 = vector.extract %0[] : f32 from vector<f32>

// Extract from 1D.
%2 = vector.from_elements %a : vector<1xf32>
%3 = vector.extract %2[0] : f32 from vector<1xf32>
%4 = vector.from_elements %a, %b, %a, %a, %b : vector<5xf32>
%5 = vector.extract %4[4] : f32 from vector<5xf32>

// Extract from 2D.
%6 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
%7 = vector.extract %6[0, 0] : f32 from vector<2x3xf32>
%8 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
%9 = vector.extract %6[1, 1] : f32 from vector<2x3xf32>
%10 = vector.extract %6[1, 2] : f32 from vector<2x3xf32>

// CHECK: return %[[a]], %[[a]], %[[b]], %[[a]], %[[a]], %[[b]], %[[b]]
return %1, %3, %5, %7, %8, %9, %10 : f32, f32, f32, f32, f32, f32, f32
}

// -----

// CHECK-LABEL: func @extract_1d_from_from_elements(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) {
%0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
// CHECK: %[[splat1:.*]] = vector.splat %[[a]] : vector<3xf32>
%1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32>
// CHECK: %[[splat2:.*]] = vector.splat %[[b]] : vector<3xf32>
%2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32>
// CHECK: return %[[splat1]], %[[splat2]]
return %1, %2 : vector<3xf32>, vector<3xf32>
}

// -----

// CHECK-LABEL: func @extract_2d_from_from_elements(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>) {
%0 = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32>
// CHECK: %[[splat1:.*]] = vector.from_elements %[[a]], %[[a]], %[[a]], %[[b]] : vector<2x2xf32>
%1 = vector.extract %0[0] : vector<2x2xf32> from vector<3x2x2xf32>
// CHECK: %[[splat2:.*]] = vector.from_elements %[[b]], %[[b]], %[[b]], %[[a]] : vector<2x2xf32>
%2 = vector.extract %0[1] : vector<2x2xf32> from vector<3x2x2xf32>
// CHECK: return %[[splat1]], %[[splat2]]
return %1, %2 : vector<2x2xf32>, vector<2x2xf32>
}

// -----

// CHECK-LABEL: func @from_elements_to_splat(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector<f32>) {
// CHECK: %[[splat:.*]] = vector.splat %[[a]] : vector<2x3xf32>
%0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32>
// CHECK: %[[from_el:.*]] = vector.from_elements {{.*}} : vector<2x3xf32>
%1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32>
// CHECK: %[[splat2:.*]] = vector.splat %[[a]] : vector<f32>
%2 = vector.from_elements %a : vector<f32>
// CHECK: return %[[splat]], %[[from_el]], %[[splat2]]
return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
}

// -----

// CHECK-LABEL: func @vector_insert_const_regression(
// CHECK: llvm.mlir.undef
// CHECK: vector.insert
Expand Down
169 changes: 169 additions & 0 deletions mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s

// This file contains some tests of folding/canonicalizing vector.from_elements

///===----------------------------------------------===//
/// Tests of `rewriteFromElementsAsSplat`
///===----------------------------------------------===//
Comment on lines +5 to +7
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This section was copied, right? Could you add a note in the summary so that it's easy to track the history?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, copied. I added a comment to the PR summary, I assume that's where you meant?


// CHECK-LABEL: func @extract_scalar_from_from_elements(
// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32)
func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) {
// Extract from 0D.
%0 = vector.from_elements %a : vector<f32>
%1 = vector.extract %0[] : f32 from vector<f32>

// Extract from 1D.
%2 = vector.from_elements %a : vector<1xf32>
%3 = vector.extract %2[0] : f32 from vector<1xf32>
%4 = vector.from_elements %a, %b, %a, %a, %b : vector<5xf32>
%5 = vector.extract %4[4] : f32 from vector<5xf32>

// Extract from 2D.
%6 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
%7 = vector.extract %6[0, 0] : f32 from vector<2x3xf32>
%8 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
%9 = vector.extract %6[1, 1] : f32 from vector<2x3xf32>
%10 = vector.extract %6[1, 2] : f32 from vector<2x3xf32>

// CHECK: return %[[A]], %[[A]], %[[B]], %[[A]], %[[A]], %[[B]], %[[B]]
return %1, %3, %5, %7, %8, %9, %10 : f32, f32, f32, f32, f32, f32, f32
}

// -----

// CHECK-LABEL: func @extract_1d_from_from_elements(
// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32)
func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) {
%0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
// CHECK: %[[SPLAT1:.*]] = vector.splat %[[A]] : vector<3xf32>
%1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32>
// CHECK: %[[SPLAT2:.*]] = vector.splat %[[B]] : vector<3xf32>
%2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32>
// CHECK: return %[[SPLAT1]], %[[SPLAT2]]
return %1, %2 : vector<3xf32>, vector<3xf32>
}

// -----

// CHECK-LABEL: func @extract_2d_from_from_elements(
// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32)
func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>) {
%0 = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32>
// CHECK: %[[SPLAT1:.*]] = vector.from_elements %[[A]], %[[A]], %[[A]], %[[B]] : vector<2x2xf32>
%1 = vector.extract %0[0] : vector<2x2xf32> from vector<3x2x2xf32>
// CHECK: %[[SPLAT2:.*]] = vector.from_elements %[[B]], %[[B]], %[[B]], %[[A]] : vector<2x2xf32>
%2 = vector.extract %0[1] : vector<2x2xf32> from vector<3x2x2xf32>
// CHECK: return %[[SPLAT1]], %[[SPLAT2]]
return %1, %2 : vector<2x2xf32>, vector<2x2xf32>
}

// -----

// CHECK-LABEL: func @from_elements_to_splat(
// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32)
func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector<f32>) {
// CHECK: %[[SPLAT:.*]] = vector.splat %[[A]] : vector<2x3xf32>
%0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32>
// CHECK: %[[FROM_EL:.*]] = vector.from_elements {{.*}} : vector<2x3xf32>
%1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32>
// CHECK: %[[SPLAT2:.*]] = vector.splat %[[A]] : vector<f32>
%2 = vector.from_elements %a : vector<f32>
// CHECK: return %[[SPLAT]], %[[FROM_EL]], %[[SPLAT2]]
return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
}

// -----

///===----------------------------------------------===//
/// Tests of `FromElementsToShapeCast`
///===----------------------------------------------===//

// CHECK-LABEL: func @to_shape_cast_rank2_to_rank1(
// CHECK-SAME: %[[A:.*]]: vector<1x2xi8>)
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[A]] : vector<1x2xi8> to vector<2xi8>
// CHECK: return %[[SHAPE_CAST]] : vector<2xi8>
func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> {
%0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
%1 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8>
%4 = vector.from_elements %0, %1 : vector<2xi8>
return %4 : vector<2xi8>
}

// -----

// CHECK-LABEL: func @to_shape_cast_rank1_to_rank3(
// CHECK-SAME: %[[A:.*]]: vector<8xi8>)
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[A]] : vector<8xi8> to vector<2x2x2xi8>
// CHECK: return %[[SHAPE_CAST]] : vector<2x2x2xi8>
func.func @to_shape_cast_rank1_to_rank3(%arg0: vector<8xi8>) -> vector<2x2x2xi8> {
%0 = vector.extract %arg0[0] : i8 from vector<8xi8>
%1 = vector.extract %arg0[1] : i8 from vector<8xi8>
%2 = vector.extract %arg0[2] : i8 from vector<8xi8>
%3 = vector.extract %arg0[3] : i8 from vector<8xi8>
%4 = vector.extract %arg0[4] : i8 from vector<8xi8>
%5 = vector.extract %arg0[5] : i8 from vector<8xi8>
%6 = vector.extract %arg0[6] : i8 from vector<8xi8>
%7 = vector.extract %arg0[7] : i8 from vector<8xi8>
%8 = vector.from_elements %0, %1, %2, %3, %4, %5, %6, %7 : vector<2x2x2xi8>
return %8 : vector<2x2x2xi8>
}

// -----

// The extracted elements are recombined into a single vector, but in a new order.
// CHECK-LABEL: func @negative_nonascending_order(
// CHECK-NOT: shape_cast
func.func @negative_nonascending_order(%arg0: vector<1x2xi8>) -> vector<2xi8> {
%0 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8>
%1 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
%2 = vector.from_elements %0, %1 : vector<2xi8>
return %2 : vector<2xi8>
}

// -----

// CHECK-LABEL: func @negative_nonstatic_extract(
// CHECK-NOT: shape_cast
func.func @negative_nonstatic_extract(%arg0: vector<1x2xi8>, %i0 : index, %i1 : index) -> vector<2xi8> {
%0 = vector.extract %arg0[0, %i0] : i8 from vector<1x2xi8>
%1 = vector.extract %arg0[0, %i1] : i8 from vector<1x2xi8>
%2 = vector.from_elements %0, %1 : vector<2xi8>
return %2 : vector<2xi8>
}

// -----

// CHECK-LABEL: func @negative_different_sources(
// CHECK-NOT: shape_cast
func.func @negative_different_sources(%arg0: vector<1x2xi8>, %arg1: vector<1x2xi8>) -> vector<2xi8> {
%0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
%1 = vector.extract %arg1[0, 1] : i8 from vector<1x2xi8>
%2 = vector.from_elements %0, %1 : vector<2xi8>
return %2 : vector<2xi8>
}

// -----

// CHECK-LABEL: func @negative_source_too_large(
// CHECK-NOT: shape_cast
func.func @negative_source_too_large(%arg0: vector<1x3xi8>) -> vector<2xi8> {
%0 = vector.extract %arg0[0, 0] : i8 from vector<1x3xi8>
%1 = vector.extract %arg0[0, 1] : i8 from vector<1x3xi8>
%2 = vector.from_elements %0, %1 : vector<2xi8>
return %2 : vector<2xi8>
}

// -----

// The inserted elements are are a subset of the extracted elements.
// [0, 1, 2] -> [1, 1, 2]
// CHECK-LABEL: func @negative_nobijection_order(
// CHECK-NOT: shape_cast
func.func @negative_nobijection_order(%arg0: vector<1x3xi8>) -> vector<3xi8> {
%0 = vector.extract %arg0[0, 1] : i8 from vector<1x3xi8>
%1 = vector.extract %arg0[0, 2] : i8 from vector<1x3xi8>
%2 = vector.from_elements %0, %0, %1 : vector<3xi8>
return %2 : vector<3xi8>
}