Skip to content

Commit 9e349f3

Browse files
committed
first
1 parent cf9b4d1 commit 9e349f3

File tree

6 files changed

+207
-4
lines changed

6 files changed

+207
-4
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ namespace vector {
5050
class ContractionOp;
5151
class TransferReadOp;
5252
class TransferWriteOp;
53+
class TransposeOp;
5354
class VectorDialect;
5455

5556
namespace detail {
@@ -171,6 +172,12 @@ SmallVector<Value> getAsValues(OpBuilder &builder, Location loc,
171172
/// `std::nullopt`.
172173
std::optional<int64_t> getConstantVscaleMultiplier(Value value);
173174

175+
/// Return true if `transpose` does not permute a pair of non-unit dims.
176+
/// By `order preserving` we mean that the flattened versions of the input and
177+
/// output vectors are (numerically) identical. In other words `transpose` is
178+
/// effectively a shape cast.
179+
bool isOrderPreserving(TransposeOp transpose);
180+
174181
//===----------------------------------------------------------------------===//
175182
// Vector Masking Utilities
176183
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,26 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
406406
void populateVectorTransposeNarrowTypeRewritePatterns(
407407
RewritePatternSet &patterns, PatternBenefit benefit = 1);
408408

409+
/// Add patterns that convert operations that are semantically equivalent to
410+
/// shape_cast, to shape_cast. Currently this includes patterns for converting
411+
/// transpose, extract and broadcast to shape_cast. Examples that will be
412+
/// converted to shape_cast are:
413+
///
414+
/// ```
415+
/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
416+
/// %1 = vector.transpose %arg1, [1, 0] : vector<2x1xi8> to vector<1x2xi8>
417+
/// %2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8>
418+
/// ```
419+
///
420+
/// Note that there is no pattern for vector.extract_strided_slice, because the
421+
/// only extract_strided_slice that is semantically equivalent to shape_cast is
422+
/// one that has idential input and output shapes, which is already folded.
423+
///
424+
/// These patterns can be useful to expose more folding opportunities by
425+
/// creating pairs of shape_casts that cancel.
426+
void populateConvertToShapeCastPatterns(RewritePatternSet &,
427+
PatternBenefit = 1);
428+
409429
/// Initialize `typeConverter` and `conversionTarget` for vector linearization.
410430
/// This registers (1) which operations are legal and hence should not be
411431
/// linearized, (2) what converted types are (rank-1 vectors) and how to

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5574,13 +5574,12 @@ LogicalResult ShapeCastOp::verify() {
55745574
return success();
55755575
}
55765576

5577-
namespace {
55785577

55795578
/// Return true if `transpose` does not permute a pair of non-unit dims.
55805579
/// By `order preserving` we mean that the flattened versions of the input and
55815580
/// output vectors are (numerically) identical. In other words `transpose` is
55825581
/// effectively a shape cast.
5583-
bool isOrderPreserving(TransposeOp transpose) {
5582+
bool mlir::vector::isOrderPreserving(TransposeOp transpose) {
55845583
ArrayRef<int64_t> permutation = transpose.getPermutation();
55855584
VectorType sourceType = transpose.getSourceVectorType();
55865585
ArrayRef<int64_t> inShape = sourceType.getShape();
@@ -5600,8 +5599,6 @@ bool isOrderPreserving(TransposeOp transpose) {
56005599
return true;
56015600
}
56025601

5603-
} // namespace
5604-
56055602
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
56065603

56075604
VectorType resultType = getType();

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2182,6 +2182,91 @@ struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
21822182
}
21832183
};
21842184

2185+
/// For example,
2186+
/// ```
2187+
/// %0 = vector.transpose %arg0, [0, 2, 1] : vector<2x1x2xf32> to
2188+
/// vector<2x2x1xf32>
2189+
/// ```
2190+
/// becomes
2191+
/// ```
2192+
/// %0 = vector.shape_cast %arg0 : vector<2x1x2xf32> to vector<2x2x1xf32>
2193+
/// ```
2194+
struct TransposeToShapeCast final
2195+
: public OpRewritePattern<vector::TransposeOp> {
2196+
using OpRewritePattern::OpRewritePattern;
2197+
LogicalResult matchAndRewrite(vector::TransposeOp transpose,
2198+
PatternRewriter &rewriter) const override {
2199+
if (!isOrderPreserving(transpose)) {
2200+
return rewriter.notifyMatchFailure(
2201+
transpose, "not order preserving, so not semantically a 'copy'");
2202+
}
2203+
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
2204+
transpose, transpose.getType(), transpose.getVector());
2205+
return success();
2206+
}
2207+
};
2208+
2209+
/// For example,
2210+
/// ```
2211+
/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
2212+
/// ```
2213+
/// becomes
2214+
/// ```
2215+
/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
2216+
/// ```
2217+
struct BroadcastToShapeCast final
2218+
: public OpRewritePattern<vector::BroadcastOp> {
2219+
using OpRewritePattern::OpRewritePattern;
2220+
LogicalResult matchAndRewrite(vector::BroadcastOp broadcast,
2221+
PatternRewriter &rewriter) const override {
2222+
auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType());
2223+
if (!sourceType) {
2224+
return rewriter.notifyMatchFailure(
2225+
broadcast, "source is a scalar, shape_cast doesn't support scalar");
2226+
}
2227+
2228+
VectorType outType = broadcast.getType();
2229+
if (sourceType.getNumElements() != outType.getNumElements())
2230+
return failure();
2231+
2232+
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcast, outType,
2233+
broadcast.getSource());
2234+
return success();
2235+
}
2236+
};
2237+
2238+
/// For example,
2239+
/// ```
2240+
/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
2241+
/// ```
2242+
/// becomes
2243+
/// ```
2244+
/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
2245+
/// ```
2246+
struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> {
2247+
using OpRewritePattern::OpRewritePattern;
2248+
LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
2249+
PatternRewriter &rewriter) const override {
2250+
VectorType sourceType = extractOp.getSourceVectorType();
2251+
VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2252+
if (!outType)
2253+
return failure();
2254+
2255+
// Negative values in `position` indicates poison, cannot convert to
2256+
// shape_cast
2257+
if (llvm::any_of(extractOp.getMixedPosition(),
2258+
[](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
2259+
return failure();
2260+
2261+
if (sourceType.getNumElements() != outType.getNumElements())
2262+
return failure();
2263+
2264+
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, outType,
2265+
extractOp.getVector());
2266+
return success();
2267+
}
2268+
};
2269+
21852270
} // namespace
21862271

21872272
void mlir::vector::populateFoldArithExtensionPatterns(
@@ -2285,6 +2370,13 @@ void mlir::vector::populateElementwiseToVectorOpsPatterns(
22852370
patterns.getContext());
22862371
}
22872372

2373+
void mlir::vector::populateConvertToShapeCastPatterns(
2374+
RewritePatternSet &patterns, PatternBenefit benefit) {
2375+
patterns
2376+
.insert<TransposeToShapeCast, BroadcastToShapeCast, ExtractToShapeCast>(
2377+
patterns.getContext(), benefit);
2378+
}
2379+
22882380
//===----------------------------------------------------------------------===//
22892381
// TableGen'd enum attribute definitions
22902382
//===----------------------------------------------------------------------===//
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
// RUN: mlir-opt %s -split-input-file -test-convert-to-shape-cast | FileCheck %s
2+
3+
4+
// CHECK-LABEL: @transpose_to_shape_cast
5+
// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32>
6+
// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
7+
// CHECK-NEXT: return %[[SCAST]] : vector<2x2x1xf32>
8+
func.func @transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> {
9+
%0 = vector.transpose %arg0, [0, 2, 1] : vector<2x1x2xf32> to vector<2x2x1xf32>
10+
return %0 : vector<2x2x1xf32>
11+
}
12+
13+
// -----
14+
15+
// CHECK-LABEL: @negative_transpose_to_shape_cast
16+
// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32>
17+
// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %[[ARG0]], [2, 0, 1]
18+
// CHECK-NEXT: return %[[TRANSPOSE]] : vector<2x2x1xf32>
19+
func.func @negative_transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> {
20+
%0 = vector.transpose %arg0, [2, 0, 1] : vector<2x1x2xf32> to vector<2x2x1xf32>
21+
return %0 : vector<2x2x1xf32>
22+
}
23+
24+
// -----
25+
26+
// CHECK-LABEL: @broadcast_to_shape_cast
27+
// CHECK-SAME: %[[ARG0:.*]]: vector<4xi8>
28+
// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
29+
// CHECK-NEXT: return %[[SCAST]] : vector<1x1x4xi8>
30+
func.func @broadcast_to_shape_cast(%arg0 : vector<4xi8>) -> vector<1x1x4xi8> {
31+
%0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
32+
return %0 : vector<1x1x4xi8>
33+
}
34+
35+
// -----
36+
37+
// CHECK-LABEL: @negative_broadcast_to_shape_cast
38+
// CHECK-NOT: shape_cast
39+
// CHECK: return
40+
func.func @negative_broadcast_to_shape_cast(%arg0 : vector<1x4xi8>) -> vector<2x3x4xi8> {
41+
%0 = vector.broadcast %arg0 : vector<1x4xi8> to vector<2x3x4xi8>
42+
return %0 : vector<2x3x4xi8>
43+
}
44+
45+
// -----
46+
47+
// CHECK-LABEL: @extract_to_shape_cast
48+
// CHECK-SAME: %[[ARG0:.*]]: vector<1x4xf32>
49+
// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
50+
// CHECK-NEXT: return %[[SCAST]] : vector<4xf32>
51+
func.func @extract_to_shape_cast(%arg0 : vector<1x4xf32>) -> vector<4xf32> {
52+
%0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
53+
return %0 : vector<4xf32>
54+
}
55+
56+
// -----
57+
58+
// In this example, arg1 might be negative indicating poison.
59+
// CHECK-LABEL: @negative_extract_to_shape_cast
60+
// CHECK-NOT: shape_cast
61+
func.func @negative_extract_to_shape_cast(%arg0 : vector<1x4xf32>, %arg1 : index) -> vector<4xf32> {
62+
%0 = vector.extract %arg0[%arg1] : vector<4xf32> from vector<1x4xf32>
63+
return %0 : vector<4xf32>
64+
}
65+

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1022,6 +1022,26 @@ struct TestEliminateVectorMasks
10221022
VscaleRange{vscaleMin, vscaleMax});
10231023
}
10241024
};
1025+
1026+
struct TestConvertToShapeCast
1027+
: public PassWrapper<TestConvertToShapeCast, OperationPass<func::FuncOp>> {
1028+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConvertToShapeCast)
1029+
1030+
TestConvertToShapeCast() = default;
1031+
1032+
StringRef getArgument() const final { return "test-convert-to-shape-cast"; }
1033+
StringRef getDescription() const final {
1034+
return "Test conversion to shape_cast of semantically equivalent ops";
1035+
}
1036+
void getDependentDialects(DialectRegistry &registry) const override {
1037+
registry.insert<vector::VectorDialect>();
1038+
}
1039+
void runOnOperation() override {
1040+
RewritePatternSet patterns(&getContext());
1041+
populateConvertToShapeCastPatterns(patterns);
1042+
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
1043+
}
1044+
};
10251045
} // namespace
10261046

10271047
namespace mlir {
@@ -1072,6 +1092,8 @@ void registerTestVectorLowerings() {
10721092
PassRegistration<vendor::TestVectorBitWidthLinearize>();
10731093

10741094
PassRegistration<TestEliminateVectorMasks>();
1095+
1096+
PassRegistration<TestConvertToShapeCast>();
10751097
}
10761098
} // namespace test
10771099
} // namespace mlir

0 commit comments

Comments
 (0)