-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[mli][vector] canonicalize vector.from_elements from ascending extracts #139819
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
c252a3d
7f40da6
3ce0713
94c5d8c
28fcceb
e985a7e
74d74f9
c584ac8
e154b22
dd3e231
0c3e6a7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,6 +33,7 @@ | |
#include "mlir/IR/OpImplementation.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/IR/TypeUtilities.h" | ||
#include "mlir/IR/ValueRange.h" | ||
#include "mlir/Interfaces/SubsetOpInterface.h" | ||
#include "mlir/Interfaces/ValueBoundsOpInterface.h" | ||
#include "mlir/Support/LLVM.h" | ||
|
@@ -2385,9 +2386,98 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp, | |
return success(); | ||
} | ||
|
||
/// Rewrite vector.from_elements as vector.shape_cast, if possible. | ||
/// | ||
/// Example: | ||
/// %0 = vector.extract %source[0, 0] : i8 from vector<1x2xi8> | ||
/// %1 = vector.extract %source[0, 1] : i8 from vector<1x2xi8> | ||
/// %2 = vector.from_elements %0, %1 : vector<2xi8> | ||
/// | ||
/// becomes | ||
/// %2 = vector.shape_cast %source : vector<1x2xi8> to vector<2xi8> | ||
/// | ||
/// The requirements for this to be valid are | ||
/// i) all elements are extracted from the same vector (source), | ||
/// ii) source and from_elements result have the same number of elements, | ||
/// iii) the elements are extracted in ascending order. | ||
/// | ||
/// It might be possible to rewrite vector.from_elements as a single | ||
/// vector.extract if (ii) is not satisifed, or in some cases as a | ||
/// a single vector_extract_strided_slice if (ii) and (iii) are not satisfied, | ||
/// this is left for future consideration. | ||
class FromElementsToShapCast : public OpRewritePattern<FromElementsOp> { | ||
public: | ||
using OpRewritePattern::OpRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(FromElementsOp fromElements, | ||
PatternRewriter &rewriter) const override { | ||
|
||
mlir::OperandRange elements = fromElements.getElements(); | ||
assert(!elements.empty() && "must be at least 1 element"); | ||
Value firstElement = elements.front(); | ||
|
||
ExtractOp extractOp = | ||
dyn_cast_if_present<vector::ExtractOp>(firstElement.getDefiningOp()); | ||
if (!extractOp) { | ||
return rewriter.notifyMatchFailure( | ||
fromElements, "first element not from vector.extract"); | ||
} | ||
|
||
VectorType sourceType = extractOp.getSourceVectorType(); | ||
Value source = extractOp.getVector(); | ||
|
||
// Check condition (ii). | ||
if (static_cast<size_t>(sourceType.getNumElements()) != elements.size()) { | ||
return rewriter.notifyMatchFailure(fromElements, | ||
"number of elements differ"); | ||
} | ||
|
||
|
||
for (auto [indexMinusOne, element] : | ||
llvm::enumerate(elements.drop_front(1))) { | ||
|
||
extractOp = | ||
dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp()); | ||
if (!extractOp) { | ||
return rewriter.notifyMatchFailure(fromElements, | ||
"element not from vector.extract"); | ||
} | ||
Value currentSource = extractOp.getVector(); | ||
// Check condition (i). | ||
if (currentSource != source) { | ||
return rewriter.notifyMatchFailure(fromElements, | ||
"element from different vector"); | ||
} | ||
|
||
|
||
ArrayRef<int64_t> position = extractOp.getStaticPosition(); | ||
assert(position.size() == static_cast<size_t>(sourceType.getRank()) && | ||
"scalar extract must have full rank position"); | ||
int64_t stride{1}; | ||
int64_t offset{0}; | ||
for (auto [pos, size] : llvm::zip(llvm::reverse(position), | ||
llvm::reverse(sourceType.getShape()))) { | ||
if (pos == ShapedType::kDynamic) { | ||
return rewriter.notifyMatchFailure( | ||
fromElements, "elements not in ascending order (dynamic order)"); | ||
} | ||
offset += pos * stride; | ||
stride *= size; | ||
} | ||
// Check condition (iii). | ||
if (offset != static_cast<int64_t>(indexMinusOne + 1)) { | ||
return rewriter.notifyMatchFailure( | ||
fromElements, "elements not in ascending order (static order)"); | ||
} | ||
} | ||
|
||
rewriter.replaceOpWithNewOp<ShapeCastOp>(fromElements, | ||
fromElements.getType(), source); | ||
return success(); | ||
} | ||
}; | ||
|
||
void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results, | ||
MLIRContext *context) { | ||
results.add(rewriteFromElementsAsSplat); | ||
results.add<FromElementsToShapCast>(context); | ||
} | ||
|
||
//===----------------------------------------------------------------------===// | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s | ||
|
||
// This file contains some tests of folding/canonicalizing vector.from_elements | ||
|
||
///===----------------------------------------------===// | ||
/// Tests of `rewriteFromElementsAsSplat` | ||
///===----------------------------------------------===// | ||
Comment on lines
+5
to
+7
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This section was copied, right? Could you add a note in the summary so that it's easy to track the history? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, copied. I added a comment to the PR summary, I assume that's where you meant? |
||
|
||
// CHECK-LABEL: func @extract_scalar_from_from_elements( | ||
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32) | ||
func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) { | ||
// Extract from 0D. | ||
%0 = vector.from_elements %a : vector<f32> | ||
%1 = vector.extract %0[] : f32 from vector<f32> | ||
|
||
// Extract from 1D. | ||
%2 = vector.from_elements %a : vector<1xf32> | ||
%3 = vector.extract %2[0] : f32 from vector<1xf32> | ||
%4 = vector.from_elements %a, %b, %a, %a, %b : vector<5xf32> | ||
%5 = vector.extract %4[4] : f32 from vector<5xf32> | ||
|
||
// Extract from 2D. | ||
%6 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32> | ||
%7 = vector.extract %6[0, 0] : f32 from vector<2x3xf32> | ||
%8 = vector.extract %6[0, 1] : f32 from vector<2x3xf32> | ||
%9 = vector.extract %6[1, 1] : f32 from vector<2x3xf32> | ||
%10 = vector.extract %6[1, 2] : f32 from vector<2x3xf32> | ||
|
||
// CHECK: return %[[a]], %[[a]], %[[b]], %[[a]], %[[a]], %[[b]], %[[b]] | ||
return %1, %3, %5, %7, %8, %9, %10 : f32, f32, f32, f32, f32, f32, f32 | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: func @extract_1d_from_from_elements( | ||
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32) | ||
func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) { | ||
%0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32> | ||
// CHECK: %[[splat1:.*]] = vector.splat %[[a]] : vector<3xf32> | ||
%1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32> | ||
// CHECK: %[[splat2:.*]] = vector.splat %[[b]] : vector<3xf32> | ||
%2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32> | ||
// CHECK: return %[[splat1]], %[[splat2]] | ||
return %1, %2 : vector<3xf32>, vector<3xf32> | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: func @extract_2d_from_from_elements( | ||
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32) | ||
func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>) { | ||
%0 = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32> | ||
// CHECK: %[[splat1:.*]] = vector.from_elements %[[a]], %[[a]], %[[a]], %[[b]] : vector<2x2xf32> | ||
%1 = vector.extract %0[0] : vector<2x2xf32> from vector<3x2x2xf32> | ||
// CHECK: %[[splat2:.*]] = vector.from_elements %[[b]], %[[b]], %[[b]], %[[a]] : vector<2x2xf32> | ||
%2 = vector.extract %0[1] : vector<2x2xf32> from vector<3x2x2xf32> | ||
// CHECK: return %[[splat1]], %[[splat2]] | ||
return %1, %2 : vector<2x2xf32>, vector<2x2xf32> | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: func @from_elements_to_splat( | ||
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32) | ||
func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector<f32>) { | ||
// CHECK: %[[splat:.*]] = vector.splat %[[a]] : vector<2x3xf32> | ||
%0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32> | ||
// CHECK: %[[from_el:.*]] = vector.from_elements {{.*}} : vector<2x3xf32> | ||
%1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32> | ||
// CHECK: %[[splat2:.*]] = vector.splat %[[a]] : vector<f32> | ||
%2 = vector.from_elements %a : vector<f32> | ||
// CHECK: return %[[splat]], %[[from_el]], %[[splat2]] | ||
return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32> | ||
} | ||
|
||
// ----- | ||
|
||
///===----------------------------------------------===// | ||
/// Tests of `FromElementsToShapeCast` | ||
///===----------------------------------------------===// | ||
|
||
// CHECK-LABEL: func @to_shape_cast_rank2_to_rank1( | ||
// CHECK-SAME: %[[a:.*]]: vector<1x2xi8>) | ||
// CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[a]] : vector<1x2xi8> to vector<2xi8> | ||
// CHECK: return %[[shape_cast]] : vector<2xi8> | ||
|
||
func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> { | ||
%0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8> | ||
%1 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8> | ||
%4 = vector.from_elements %0, %1 : vector<2xi8> | ||
return %4 : vector<2xi8> | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: func @to_shape_cast_rank1_to_rank3( | ||
// CHECK-SAME: %[[a:.*]]: vector<8xi8>) | ||
// CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[a]] : vector<8xi8> to vector<2x2x2xi8> | ||
// CHECK: return %[[shape_cast]] : vector<2x2x2xi8> | ||
func.func @to_shape_cast_rank1_to_rank3(%arg0: vector<8xi8>) -> vector<2x2x2xi8> { | ||
%0 = vector.extract %arg0[0] : i8 from vector<8xi8> | ||
%1 = vector.extract %arg0[1] : i8 from vector<8xi8> | ||
%2 = vector.extract %arg0[2] : i8 from vector<8xi8> | ||
%3 = vector.extract %arg0[3] : i8 from vector<8xi8> | ||
%4 = vector.extract %arg0[4] : i8 from vector<8xi8> | ||
%5 = vector.extract %arg0[5] : i8 from vector<8xi8> | ||
%6 = vector.extract %arg0[6] : i8 from vector<8xi8> | ||
%7 = vector.extract %arg0[7] : i8 from vector<8xi8> | ||
%8 = vector.from_elements %0, %1, %2, %3, %4, %5, %6, %7 : vector<2x2x2xi8> | ||
return %8 : vector<2x2x2xi8> | ||
} | ||
|
||
// ----- | ||
|
||
// The extracted elements are recombined into a single vector, but in a new order. | ||
// CHECK-LABEL: func @negative_nonascending_order( | ||
// CHECK-NOT: shape_cast | ||
func.func @negative_nonascending_order(%arg0: vector<1x2xi8>) -> vector<2xi8> { | ||
%0 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8> | ||
%1 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8> | ||
%2 = vector.from_elements %0, %1 : vector<2xi8> | ||
return %2 : vector<2xi8> | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: func @negative_nonstatic_extract( | ||
// CHECK-NOT: shape_cast | ||
func.func @negative_nonstatic_extract(%arg0: vector<1x2xi8>, %i0 : index, %i1 : index) -> vector<2xi8> { | ||
%0 = vector.extract %arg0[0, %i0] : i8 from vector<1x2xi8> | ||
%1 = vector.extract %arg0[0, %i1] : i8 from vector<1x2xi8> | ||
%2 = vector.from_elements %0, %1 : vector<2xi8> | ||
return %2 : vector<2xi8> | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: func @negative_different_sources( | ||
// CHECK-NOT: shape_cast | ||
func.func @negative_different_sources(%arg0: vector<1x2xi8>, %arg1: vector<1x2xi8>) -> vector<2xi8> { | ||
%0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8> | ||
%1 = vector.extract %arg1[0, 1] : i8 from vector<1x2xi8> | ||
%2 = vector.from_elements %0, %1 : vector<2xi8> | ||
return %2 : vector<2xi8> | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: func @negative_source_too_large( | ||
// CHECK-NOT: shape_cast | ||
func.func @negative_source_too_large(%arg0: vector<1x3xi8>) -> vector<2xi8> { | ||
%0 = vector.extract %arg0[0, 0] : i8 from vector<1x3xi8> | ||
%1 = vector.extract %arg0[0, 1] : i8 from vector<1x3xi8> | ||
%2 = vector.from_elements %0, %1 : vector<2xi8> | ||
return %2 : vector<2xi8> | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[ultra nit] "if possible" is implicit and "less is more" :)