[vector][mlir] Canonicalize to shape_cast where possible #140583
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter.
@banach-space I'm getting back to this PR. Peephole question: is this operation ok? i.e. is
vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>
an acceptable operation to have after running mlir-opt -arm-sme-vector-legalization -cse -canonicalize ?
In general, yes. But I can't guarantee there's no logic that expects vector<[4]x1xf32> instead of vector<1x[4]xf32> ;-) If that's the case, we will fix it and I will be grateful for uncovering this :)
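For intuition on why this is generally fine: a shape_cast only reinterprets the shape, never the linear element order. The following NumPy sketch (an analogy only, with the scalable [4] dim modeled as a fixed 4) shows that reshaping (4, 1) to (1, 4) and transposing [1, 0] produce identical results, so no data movement is implied:

```python
import numpy as np

# A shape_cast only reinterprets the shape: the flattened element order is
# unchanged. Reshaping (4, 1) -> (1, 4) moves no data, which is why the
# transpose [1, 0] on a (4, 1) vector can be expressed as a cast.
src = np.arange(4, dtype=np.float32).reshape(4, 1)

as_cast = src.reshape(1, 4)          # shape_cast analogue
as_transpose = src.transpose(1, 0)   # transpose [1, 0] analogue

# Both produce the same elements in the same linear order.
assert np.array_equal(as_cast, as_transpose)
assert np.array_equal(as_cast.ravel(), src.ravel())
```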
Author note: I've removed this, as now it happens in 2 steps during canonicalization. The first converts the Broadcast to a ShapeCast. The second combines the 2 ShapeCasts.
Author note: I've removed this, as it now happens in 2 steps during canonicalization. The first (new) step is to rewrite the transpose as a shape_cast. The second step is to fold shape_cast(shape_cast) to shape_cast.
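The two-step claim above can be sanity-checked with a NumPy analogy (illustrative only, not the MLIR semantics): a transpose that only moves a unit dim is order preserving, hence a reshape, and two consecutive reshapes compose into one.

```python
import numpy as np

a = np.arange(6, dtype=np.int8).reshape(2, 1, 3)

# Step 1: transpose [0, 2, 1] only moves the unit dim, so it is
# order preserving and equivalent to a reshape (shape_cast analogue).
step1 = a.reshape(2, 3, 1)
assert np.array_equal(step1, a.transpose(0, 2, 1))

# Step 2: shape_cast(shape_cast(x)) folds to a single shape_cast,
# because a composition of reshapes is itself a reshape.
step2 = step1.reshape(6)
assert np.array_equal(step2, a.reshape(6))
```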
Author note: I've removed this pattern, as it is a special case of TransposeToShapeCast
Author note: removed these tests, as the pattern they are testing is removed
Shouldn't we keep them? Shouldn't they still be canonicalized?
I'll add them back, yes they're still canonicalized
Author note: as the vector.transpose is canonicalized to a vector.shape_cast, the lowering test is now moved to shape_cast lowering
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-sme

Author: James Newling (newling)

Changes: Discussions suggest that we should use shape_cast as a canonical form of broadcast/transpose/extract where possible (see #138777). For example, these can all be expressed as shape casts:

%0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
%1 = vector.transpose %arg1, [1, 0] : vector<2x1xi8> to vector<1x2xi8>
%2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8>

This PR adds canonicalizations to convert the above 3 examples to shape_casts. I've added some more comments as review comments. I'm happy to split this PR up and add the new patterns separately.

Patch is 41.81 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140583.diff

10 Files Affected:
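As a quick sanity check that each of the three ops above is element-order preserving, here is a NumPy sketch (NumPy is used only as an analogy for the Vector semantics; the variable names mirror the IR above):

```python
import numpy as np

# %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
arg0 = np.arange(4, dtype=np.int8)
b = np.broadcast_to(arg0, (1, 1, 4))
assert np.array_equal(b, arg0.reshape(1, 1, 4))

# %1 = vector.transpose %arg1, [1, 0] : vector<2x1xi8> to vector<1x2xi8>
arg1 = np.arange(2, dtype=np.int8).reshape(2, 1)
t = arg1.transpose(1, 0)
assert np.array_equal(t, arg1.reshape(1, 2))

# %2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8>
arg2 = np.arange(4, dtype=np.int8).reshape(1, 4)
e = arg2[0]
assert np.array_equal(e, arg2.reshape(4))
```

In all three cases the result holds exactly the source elements in the source order, which is the precondition each new pattern checks (equal element counts, positions all zero).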
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 862ed7bae1fbb..08cc4af158e10 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2351,11 +2351,41 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
return success();
}
+/// BEFORE:
+/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
+/// AFTER:
+/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+ PatternRewriter &rewriter) const override {
+ VectorType sourceType = extractOp.getSourceVectorType();
+ VectorType outType = dyn_cast<VectorType>(extractOp.getType());
+ if (!outType)
+ return failure();
+
+ // Negative values in `position` indicates poison, which cannot be
+ // represented with a shape_cast
+ if (llvm::any_of(extractOp.getMixedPosition(),
+ [](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
+ return failure();
+
+ if (sourceType.getNumElements() != outType.getNumElements())
+ return failure();
+
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, outType,
+ extractOp.getVector());
+ return success();
+ }
+};
+
} // namespace
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+ results
+ .add<ExtractOpFromBroadcast, ExtractOpFromCreateMask, ExtractToShapeCast>(
+ context);
results.add(foldExtractFromShapeCastToShapeCast);
results.add(foldExtractFromFromElements);
}
@@ -2867,13 +2897,36 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
return success();
}
};
+
+/// BEFORE:
+/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+/// AFTER:
+/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+struct BroadcastToShapeCast final
+ : public OpRewritePattern<vector::BroadcastOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::BroadcastOp broadcast,
+ PatternRewriter &rewriter) const override {
+ auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType());
+ if (!sourceType) {
+ return rewriter.notifyMatchFailure(
+ broadcast, "source is a scalar, shape_cast doesn't support scalar");
+ }
+
+ VectorType outType = broadcast.getType();
+ if (sourceType.getNumElements() != outType.getNumElements())
+ return failure();
+
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcast, outType,
+ broadcast.getSource());
+ return success();
+ }
+};
} // namespace
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
- // calling `populateCastAwayVectorLeadingOneDimPatterns`
- results.add<BroadcastFolder>(context);
+ results.add<BroadcastFolder, BroadcastToShapeCast>(context);
}
//===----------------------------------------------------------------------===//
@@ -5991,10 +6044,7 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
}
};
-/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
-/// i) Y = ShapeCast(X), or
-/// ii) Y = Broadcast(X)
-/// If both (i) and (ii) are possible, (i) is chosen.
+/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as Y = Broadcast(X)
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
public:
using OpRewritePattern::OpRewritePattern;
@@ -6009,22 +6059,6 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
bool srcIsScalar = !srcVectorType;
- // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
- // Example:
- // %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
- // %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
- // to
- // %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
- if (srcVectorType) {
- if (srcVectorType.getNumElements() ==
- shapeCastOp.getResultVectorType().getNumElements()) {
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
- shapeCastOp, shapeCastOp.getResultVectorType(),
- broadcastOp.getSource());
- return success();
- }
- }
-
// Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
// Example
// %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
@@ -6233,7 +6267,7 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
// %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
// %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
//
- // Example of what NOT to fold:
+ // Example of what not to fold:
// %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
//
if (getSourceVectorType() == getResultVectorType() &&
@@ -6359,32 +6393,6 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
}
};
-/// Folds transpose(shape_cast) into a new shape_cast.
-class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(TransposeOp transposeOp,
- PatternRewriter &rewriter) const override {
- auto shapeCastOp =
- transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
- if (!shapeCastOp)
- return failure();
- if (!isOrderPreserving(transposeOp))
- return failure();
-
- VectorType resultType = transposeOp.getType();
-
- // We don't need to check isValidShapeCast at this point, because it is
- // guaranteed that merging the transpose into the the shape_cast is a valid
- // shape_cast, because the transpose just inserts/removes ones.
-
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
- shapeCastOp.getSource());
- return success();
- }
-};
-
/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
/// 'order preserving', where 'order preserving' means the flattened
/// inputs and outputs of the transpose have identical (numerical) values.
@@ -6480,12 +6488,35 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
}
};
+/// BEFORE:
+/// %0 = vector.transpose %arg0, [0, 2, 1] :
+/// vector<2x1x2xf32> to vector<2x2x1xf32>
+/// AFTER:
+/// %0 = vector.shape_cast %arg0 :
+/// vector<2x1x2xf32> to vector<2x2x1xf32>
+struct TransposeToShapeCast final
+ : public OpRewritePattern<vector::TransposeOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::TransposeOp transpose,
+ PatternRewriter &rewriter) const override {
+
+ if (!isOrderPreserving(transpose)) {
+ return rewriter.notifyMatchFailure(
+ transpose, "not order preserving, so not semantically a 'copy'");
+ }
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+ transpose, transpose.getType(), transpose.getVector());
+ return success();
+ }
+};
+
} // namespace
void vector::TransposeOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
- results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
- FoldTransposeSplat, FoldTransposeBroadcast>(context);
+ results.add<FoldTransposeBroadcast, FoldTransposeCreateMask,
+ FoldTransposeSplat, TransposeFolder, TransposeToShapeCast>(
+ context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 732e316c93381..71410eda28297 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -11,7 +11,6 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -382,64 +381,6 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
vector::VectorTransposeLowering vectorTransposeLowering;
};
-/// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied
-/// to 2D vectors with at least one unit dim. For example:
-///
-/// Replace:
-/// vector.transpose %0, [1, 0] : vector<4x1xi32>> to
-/// vector<1x4xi32>
-/// with:
-/// vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32>
-///
-/// Source with leading unit dim (inverse) is also replaced. Unit dim must
-/// be fixed. Non-unit dim can be scalable.
-///
-/// TODO: This pattern was introduced specifically to help lower scalable
-/// vectors. In hindsight, a more specialised canonicalization (for shape_cast's
-/// to cancel out) would be preferable:
-///
-/// BEFORE:
-/// %0 = some_op
-/// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32>
-/// %2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
-/// AFTER:
-/// %0 = some_op
-/// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32>
-///
-/// Given the context above, we may want to consider (re-)moving this pattern
-/// at some later time. I am leaving it for now in case there are other users
-/// that I am not aware of.
-class Transpose2DWithUnitDimToShapeCast
- : public OpRewritePattern<vector::TransposeOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- Transpose2DWithUnitDimToShapeCast(MLIRContext *context,
- PatternBenefit benefit = 1)
- : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
-
- LogicalResult matchAndRewrite(vector::TransposeOp op,
- PatternRewriter &rewriter) const override {
- Value input = op.getVector();
- VectorType resType = op.getResultVectorType();
-
- // Set up convenience transposition table.
- ArrayRef<int64_t> transp = op.getPermutation();
-
- if (resType.getRank() == 2 &&
- ((resType.getShape().front() == 1 &&
- !resType.getScalableDims().front()) ||
- (resType.getShape().back() == 1 &&
- !resType.getScalableDims().back())) &&
- transp == ArrayRef<int64_t>({1, 0})) {
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
- return success();
- }
-
- return failure();
- }
-};
-
/// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
/// If the strategy is Shuffle1D, it will be lowered to:
/// vector.shape_cast 2D -> 1D
@@ -511,8 +452,6 @@ class TransposeOp2DToShuffleLowering
void mlir::vector::populateVectorTransposeLoweringPatterns(
RewritePatternSet &patterns,
VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit) {
- patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(),
- benefit);
patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
vectorTransposeLowering, patterns.getContext(), benefit);
}
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index 6cdf576272ebc..a9a2fdccdd82f 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -480,11 +480,11 @@ func.func @lift_illegal_transpose_to_memory_with_in_bounds_attr(%a: index, %b: i
// -----
-// The pass should do nothing (and not crash).
-// CHECK-LABEL: @illegal_transpose_no_defining_source_op
-func.func @illegal_transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32>
+// CHECK-LABEL: @transpose_no_defining_source_op
+func.func @transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32>
{
- // CHECK: vector.transpose
+ // CHECK: vector.shape_cast
+ // CHECK-SAME: vector<[4]x1xf32> to vector<1x[4]xf32>
%0 = vector.transpose %vec, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
return %0 : vector<1x[4]xf32>
}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 65b73375831da..374c71c814e89 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -451,16 +451,25 @@ func.func @extract_strided_fold_insert(%a: vector<2x8xf32>, %b: vector<1x4xf32>,
// -----
// CHECK-LABEL: transpose_3D_identity
-// CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>)
+// CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>)
+// CHECK-NEXT: return [[ARG]]
func.func @transpose_3D_identity(%arg : vector<4x3x2xf32>) -> vector<4x3x2xf32> {
- // CHECK-NOT: transpose
%0 = vector.transpose %arg, [0, 1, 2] : vector<4x3x2xf32> to vector<4x3x2xf32>
- // CHECK-NEXT: return [[ARG]]
return %0 : vector<4x3x2xf32>
}
// -----
+// CHECK-LABEL: transpose_0D_identity
+// CHECK-SAME: ([[ARG:%.*]]: vector<i8>)
+// CHECK-NEXT: return [[ARG]]
+func.func @transpose_0D_identity(%arg : vector<i8>) -> vector<i8> {
+ %0 = vector.transpose %arg, [] : vector<i8> to vector<i8>
+ return %0 : vector<i8>
+}
+
+// -----
+
// CHECK-LABEL: transpose_2D_sequence
// CHECK-SAME: ([[ARG:%.*]]: vector<4x3xf32>)
func.func @transpose_2D_sequence(%arg : vector<4x3xf32>) -> vector<4x3xf32> {
@@ -753,12 +762,13 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
// -----
+
// CHECK-LABEL: negative_fold_extract_broadcast
-// CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
-// CHECK: vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32>
+// CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x2x4xf32>
+// CHECK: vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x2x4xf32>
func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32> {
- %b = vector.broadcast %a : vector<1x1xf32> to vector<1x1x4xf32>
- %r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x1x4xf32>
+ %b = vector.broadcast %a : vector<1x1xf32> to vector<1x2x4xf32>
+ %r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x2x4xf32>
return %r : vector<4xf32>
}
@@ -797,8 +807,8 @@ func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
// rank(extract_output) < rank(broadcast_input)
func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
%idx0 : index, %idx1 : index) -> vector<4xf32> {
- %b = vector.broadcast %a : vector<2x4xf32> to vector<1x2x4xf32>
- %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
+ %b = vector.broadcast %a : vector<2x4xf32> to vector<2x2x4xf32>
+ %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<2x2x4xf32>
return %r : vector<4xf32>
}
@@ -1033,30 +1043,6 @@ func.func @canonicalize_broadcast_shapecast_to_broadcast_scalar(%arg0: f32) -> v
// -----
-// In this test, broadcast (2)->(1,2,1) is not legal, but shape_cast (2)->(1,2,1) is.
-// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapcast
-// CHECK-NOT: vector.broadcast
-// CHECK: vector.shape_cast {{.+}} : vector<2xf32> to vector<1x2x1xf32>
-func.func @canonicalize_broadcast_shapecast_to_shapcast(%arg0 : vector<2xf32>) -> vector<1x2x1xf32> {
- %0 = vector.broadcast %arg0 : vector<2xf32> to vector<1x2xf32>
- %1 = vector.shape_cast %0 : vector<1x2xf32> to vector<1x2x1xf32>
- return %1 : vector<1x2x1xf32>
-}
-
-// -----
-
-// In this test, broadcast (1)->(1,1) and shape_cast (1)->(1,1) are both legal. shape_cast is chosen.
-// CHECK-LABEL: func @canonicalize_broadcast_shapecast_both_possible
-// CHECK-NOT: vector.broadcast
-// CHECK: vector.shape_cast {{.+}} : vector<1xf32> to vector<1x1xf32>
-func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>) -> vector<1x1xf32> {
- %0 = vector.broadcast %arg0 : vector<1xf32> to vector<1x1x1xf32>
- %1 = vector.shape_cast %0 : vector<1x1x1xf32> to vector<1x1xf32>
- return %1 : vector<1x1xf32>
-}
-
-// -----
-
// CHECK-LABEL: fold_vector_transfer_masks
func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
@@ -1920,12 +1906,12 @@ func.func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> {
// -----
-// CHECK-LABEL: func @insert_extract_to_broadcast
+// CHECK-LABEL: func @insert_extract_to_shape_cast
// CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>)
-// CHECK: %[[V0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
-// CHECK: %[[V1:.*]] = vector.broadcast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
+// CHECK: %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32>
+// CHECK: %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
// CHECK: return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32>
-func.func @insert_extract_to_broadcast(%arg0 : vector<1x1x4xf32>,
+func.func @insert_extract_to_shape_cast(%arg0 : vector<1x1x4xf32>,
%arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) {
%0 = vector.extract %arg0[0, 0] : vector<4xf32> from vector<1x1x4xf32>
%1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32>
@@ -2277,7 +2263,7 @@ func.func @shuffle_1d_rhs_poison() -> vector<4xi32> {
// CHECK-LABEL: func @shuffle_canonicalize_0d
func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vector<1xi32> {
- // CHECK: vector.broadcast %{{.*}} : vector<i32> to vector<1xi32>
+ // CHECK: vector.shape_cast %{{.*}} : vector<i32> to vector<1xi32>
%shuffle = vector.shuffle %v0, %v1 [0] : vector<i32>, vector<i32>
return %shuffle : vector<1xi32>
}
@@ -2764,9 +2750,8 @@ func.func @transfer_read_from_rank_reducing_extract_slice(%src: tensor<1x8x8x8xf
// CHECK-LABEL: func.func @extract_from_broadcast
func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> {
%0 = vector.broadcast %src : vector<1x1x1xf32> to vector<1x1x32x1xf32>
-
- // CHECK-NEXT: %0 = vector.extract {{.*}}[0, 0] : vector<1xf32> from vector<1x1x1xf32>
- // CHECK-NEXT: return %0 : vector<1xf32>
+ // CHECK-NEXT: %[[RES:.*]] = vector.shape_cast{{.*}} vector<1x1x1xf32> to vector<1xf32>
+ // CHECK-NEXT: return %[[RES]] : vector<1xf32>
%1 = vector.extract %0[0, 0, 31] : vector<1xf32> from vector<1x1x32x1xf32>
return %1: vector<1xf32>
}
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
index fdab2a8918a2e..d5f96a8928770 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
@@ -81,8 +81,8 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<
// CHECK-LABEL: func @to_shape_cast_rank2_to_rank1(
// CHECK-SAME: %[[A:.*]]: vector<1x2xi8>)
-// CHECK: %[[EXTRACT:.*]] = vector.extract %[[A]][0] : vector<2xi8> from vector<1x2xi8>
-// CHECK: return %[[EXTRACT]] : vector<2xi8>
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[A]] : vector<1x2xi8> to vector<2xi8>
+// CHECK: return %[[SC]] : vector<2xi8>
func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> {
%0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
%1 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8>
diff --git...
[truncated]
Hi @banach-space and @dcaballe, I've pulled this PR out of draft mode, so please feel free to comment on it whenever!
dcaballe left a comment
Nice! LGTM in general. The only general comment is to make sure we don't reduce testing coverage. I think we should keep/update the tests even for those cases where the pattern is removed.
Keep both tests, one with the original shape and one with the new ones?
Unrelated: it looks like we are missing a canonicalization pattern here? This should be turned into a single vector.broadcast to vector<4xf32>?
Keep both tests, one with the original shape and one with the new ones?
Makes sense, will do.
Unrelated: it looks like we are missing a canonicalization patter here? This should be turned into a single vector.broadcast to vector<4xf32>?
No because you can't broadcast <1x1xf32> to <4xf32> -- broadcasts can never reduce rank in Vector. FWIW slightly related to my comment here where this would be simpler if ops didn't do implicit shape casting. In this case if it was something like
%s = vector.shape_cast %a : vector<1x1xf32> to vector<1x1x1xf32>
%b = vector.broadcast %s : vector<1x1x1xf32> to vector<1x2x4xf32>
%r = vector.extract %b[0, 0] : vector<1x1x4xf32> from vector<1x2x4xf32>
%s = vector.shape_cast %r : vector<1x1x4> to vector<4>
ie if we constrained broadcasts and extracts to be rank retaining, then this would be canonicalized to
%s = vector.shape_cast %a : vector<1x1xf32> to vector<1x1x1xf32>
%b = vector.broadcast %s : vector<1x1x1xf32> to vector<1x1x4xf32>
%s = vector.shape_cast %b : vector<1x1x4> to vector<4>
which, if you have faith that the shape_casts will vanish at a later point, is simpler!
p.s. I plan to reply in #145740 later today
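The rank constraint mentioned above (a Vector broadcast can never reduce rank) has a direct NumPy analogue, sketched here purely for illustration:

```python
import numpy as np

a = np.ones((1, 1), dtype=np.float32)

# Broadcasting never reduces rank: (1, 1) -> (4,) is rejected, which
# mirrors why <1x1xf32> cannot be broadcast straight to <4xf32>.
try:
    np.broadcast_to(a, (4,))
    raised = False
except ValueError:
    raised = True
assert raised

# With a rank-reducing cast (reshape) first, the broadcast is legal.
out = np.broadcast_to(a.reshape(1), (4,))
assert out.shape == (4,)
```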
Shouldn't we keep them? Shouldn't they still be canonicalized?
Thanks! I ran the SME e2e tests and all pass. I wasn't able to cherry-pick this in IREE though, getting weird compilation errors. Though upstream tests should be sufficient to surface all potential issues. @newling, why not name all "folding" patterns as
banach-space left a comment
Thanks!
Why change shapes?
I'll give this a spin with IREE
Yes, I think so. Actually fe3933d made me wonder if we should split canonicalize.mlir into 2 files (the new one, named fold.mlir, containing everything in canonicalize.mlir that only depends on one-time folds).
@banach-space and @dcaballe thanks for your feedback! Unfortunately I'm going to put this on hold again temporarily, as I've uncovered some other things which should be done before this. Moving back into draft mode, will ping when I think it's ready again.
+1
This PR is back, and ready for review! Let me summarize the previous concerns, as this is quite old now: @dcaballe raised concerns about removing tests. I have reinstated all canonicalization tests.
What is this PR achieving?
To me, this change isn't really about removing unit dims. It's about improving the state of In the distant past (~4yrs ago) we did use Without a crisp contract,
Context + previous PRs
Note, James has already proposed/implemented a number of steps to improve
This PR simply takes things further (and was discussed before: #138777). I just wanted to highlight - there have been a number of PRs through which we concluded that this would be beneficial.
Why take this approach?
Indeed! With multiple ways of expressing identical computation, let's just agree on one. I don't see any information being lost, so to me this is a matter of simply agreeing on what the canonical form should be and then working with that. From #138777:
Yes! Clarifying this takes us to a cleaner, better defined state - that's beneficial.
As a specific argument for this change, replacing:
%0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
with a shape_cast.
Is this really a canonicalization?
Yes. The original Ops are mechanically recoverable in every case, so I don't see any loss of information here.
I believe so, yes.
Do we need an RFC?
We are discussing fine implementation details.
LGTM?
Yes
I don't mind which way this PR goes. The original goal of the PR was to generate better code in a simpler way, not to define a canonical form. Converting to shape_cast achieves the original goal, but we don't have to say that shape_cast is a canonical form. As already mentioned in previous comments, technically it is a canonicalization if we want it to be, but until we formalize a canonical form (ideally with an energy function with guarantees of converging to a fixed point, and clear documentation and motivation in a single obvious place) maybe we're better off leaving things that don't directly reduce op count or provide some other obvious generic improvement out. As someone who's tried to unravel complex decisions made 4+ years ago in this dialect (and often not followed through on), I'd rather not leave a whole new paper trail for future developers. What I've actually been considering for a while is making a 2-step pass that
I think that would be quite powerful. The nice thing about
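The 2-step pass idea sketched above (first rewrite eligible ops as shape_casts, then greedily fold adjacent shape_casts to a fixed point) can be illustrated with a toy rewriter over a hypothetical op representation — tuples of (name, src_shape, dst_shape). Everything here is illustrative, not the MLIR API:

```python
def num_elements(shape):
    n = 1
    for d in shape:
        n *= d
    return n

def to_shape_cast(op):
    """Step 1: rewrite broadcast/transpose/extract as shape_cast when the
    element count is unchanged (the 'semantically a copy' case)."""
    name, src, dst = op
    if name in ("broadcast", "transpose", "extract") and \
            num_elements(src) == num_elements(dst):
        return ("shape_cast", src, dst)
    return op

def fold_chain(ops):
    """Step 2: greedily fold shape_cast(shape_cast(x)) into one
    shape_cast, repeating until a fixed point is reached."""
    ops = [to_shape_cast(op) for op in ops]
    changed = True
    while changed:
        changed = False
        for i in range(len(ops) - 1):
            a, b = ops[i], ops[i + 1]
            if a[0] == b[0] == "shape_cast" and a[2] == b[1]:
                ops[i:i + 2] = [("shape_cast", a[1], b[2])]
                changed = True
                break
    return ops

# transpose (4,1)->(1,4) becomes a shape_cast, then merges with the next.
chain = [("transpose", (4, 1), (1, 4)), ("shape_cast", (1, 4), (4,))]
assert fold_chain(chain) == [("shape_cast", (4, 1), (4,))]
```

Note how the fixed-point loop mirrors what the greedy pattern rewriter would do once the conversion patterns and the shape_cast(shape_cast) fold coexist.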
I'm sorry, but this response is a bit confusing to me. I'm very aware of the previous PRs by James, because James and I are aligned on what we are working towards. I didn't chime in on those PRs because I didn't want two people with the same ideas, from the same team and working on similar things, to just review things between them and land things, given that we have multiple maintainers for the Vector dialect. I do skim those PRs; I just didn't gather until the title of this PR that we came to a conclusion that these patterns should be canonicalizations. In fact, I'm more confused that we talk about context here, given that the same change was reverted in the past and I distinctly remember we never came to a conclusion on what the right form is here: #72918. I do not see that we resolved any of the concerns in that discussion, so I don't think concluding this is beneficial is correct, without previous context. There was even an RFC where @banach-space and @dcaballe both were present: https://discourse.llvm.org/t/improving-handling-of-unit-dimensions-in-the-vector-dialect/75216/29. I don't see anybody following up on any of those discussions, so we cannot throw those conclusions away and proceed with something a lot of reviewers disagreed with.
Given that there was already an RFC about this, with no followup on things, yes we need an RFC on this. A lot of vector dialect users were present in that discussion whose concerns need to be addressed. I'm not okay with this proceeding as a PR given that there was an RFC about this where this PR was not the conclusion.
I don't know when we concluded that this "experiment" failed. I see that in that RFC @nicolasvasilache had the same point about vector.shape_cast: https://discourse.llvm.org/t/improving-handling-of-unit-dimensions-in-the-vector-dialect/75216/10. This is why we need to address the points in the RFC before we start a discussion on these patterns being canonicalizations.
This is exactly the situation we are trying to improve! There have been numerous complaints along the years about the arbitrary set of disconnected patterns that we have in the Vector dialect and how difficult it is for someone new (and I would say for someone with experience starting a new compiler) to put all of them together in the right order to produce an optimized output. I'm ok with having different passes that are able to convert one form to another, but we need a common starting point, which should be the canonical one.
That RFC is from 2023. Quite a few things have changed since then. The RFC referenced above is also a good example of what can happen when we pull too many voices into a discussion: the discussion derailed and there were people talking over each other with counter-proposals, and no conclusion was reached. I know not everyone agrees with this, but the Tensor WG was built with the intent to improve on this situation and help build a certain level of consensus before having a broader discussion. IMO, we should start there. Hopefully @rengolin or @rolfmorel can help organize. I think some people are on vacation right now so we may have to wait for a few days/weeks.
|
Returning after some OOO + sick leave. I wanted to clarify my earlier feedback.
On context: @Groverkss, I wasn’t suggesting you weren’t aware of previous PRs. My intent was to provide context for other reviewers (including, but not limited to, you) so they could see that this work is part of a longer effort, not an isolated change, and where my support was coming from.
On the 2023 concerns: The PR you linked is ~2 years old. It’s not clear to me whether the issues raised there still apply. If they do, could someone summarize them for us here? The earlier discussions involved many people and spanned months, but there’s been little follow-up since. That suggests (to me) that the concerns may no longer be relevant, especially given the changes in the ecosystem that Diego pointed out.
On process: I share the worry that RFCs can stall progress. We’ve seen this before (e.g. the 0-D vectors RFC that took 6 months and required an ODM). That said, I’m not against RFCs, just cautious that we don’t end up in the same situation again. The Tensor WG seems like a good venue to first align and then escalate if needed.
Next step: If anyone believes the 2023 concerns still block this direction, could you restate them here? That would help us decide whether an RFC is warranted, or whether we can move forward with this canonicalization. In the meantime, I added this to the Tensor WG agenda. Thank you, |
While this was on the agenda for our call today, we didn't get a chance to discuss it (notes from today: https://discourse.llvm.org/t/tcdg-tensor-compiler-design-group-meeting-notes-2025-09-17) - we prioritised the unrolling vs flattening topic that was raised earlier. We'll get back to this in the next Tensor WG meeting! |
The overall sentiment was that we should proceed with this unless specific and actionable concerns are raised. Regarding the earlier discussion on Discourse: much has changed since then, and no concrete follow-ups occurred. Given the ~2-year gap, it’s reasonable to assume that the previous concerns no longer hold. Given LLVM’s liberal revert policy, I propose that we land this and, if issues arise, revert. We’ve been converging toward this change for some time. -Andrzej |
banach-space
left a comment
LGTM, the overall direction makes a lot of sense to me, massive thanks for working on this!
|
```cpp
if (!isOrderPreserving(transpose)) {
  return rewriter.notifyMatchFailure(
      transpose, "not order preserving, so not semantically a 'copy'");
```
Did you mean "semantically" without "not"?

```diff
-      transpose, "not order preserving, so not semantically a 'copy'");
+      transpose, "not order preserving, so semantically a 'copy'");
```
```cpp
/// to fold shape_cast into create_mask.
///
/// BEFORE:
/// to fold shape_cast into creatto a greater number of BEFORE:
```
Something has gone wrong here :)
Hi, I'll try to write up the concerns by Monday next week and ask for an RFC on how those concerns are addressed. I have raised concerns about this, other people have raised concerns about this, and if you want to proceed please send an RFC. I am happy to summarise all problems with this change by Monday. Landing this without an RFC, even when one was requested, would be against LLVM policy. |
Groverkss
left a comment
This change has had problems in the past, with concerns raised by @MaheshRavishankar @kuhar @qedawkins and me. I have talked with them privately recently, and we still have the same concerns. I will summarise the previous concerns by Monday.
Can you please send an RFC summarising what has changed and how you plan to handle the previously raised concerns? Happy to send a summary of the concerns by Monday which you can use.
Please do not land this without sending an RFC and getting explicit approval on the RFC.
|
I did write a summary of previous concerns today, but I'm talking to others who were on previous threads to check whether the concerns are still valid (some of them don't look valid to me). I'll post an update tomorrow; due to timezone differences I couldn't sync with some people in US timezones today. |
|
The main issue with this canonicalization comes down to the fact that it drops information. The crux of the issue raised in the previous PRs, reverts and Discourse threads is that these operations are not interchangeable in what they tell you. This is visible when describing each of these operations based on how their output dimensions map back to their input dimensions. Note that transpose, broadcast and extract represent a projected permutation from input to output dimensions, while shape_cast carries no such mapping. When coming from tensor land via vectorization, the vector code produced preserves these dimension mappings. This PR tries to canonicalize different ops, which can do different things but have the same effect on the data layout, into a single op. The reasonings I'm seeing in this PR are:
I don't see canonicalizing towards a less restricted op as a goal of canonicalization. In fact, that is only the initial bar for something to be considered as a candidate.
It's not the bar for something to be a canonicalization. The goal that is relevant here, and that I do see, is that a canonical form should make subsequent transformations and analyses easier to apply.
There is no justification in this PR that canonicalizing towards a single operation makes subsequent transformations easier. Let's say I'm writing a transformation (a commonly written one) which folds transposes into the indexing maps of a consuming contraction. It is easy to write by just matching operations that have transpose-like semantics and a stated permutation. Once these ops are canonicalized away, not only does my transformation have to look for transpose-like ops, it has to also analyze every shape_cast to decide whether it is semantically a transpose at all, and if so which one. This is making subsequent transformations harder to apply; this is actively going against the stated goal of canonicalization. In practice, we have a number of patterns in the SPIRV backend in our downstream projects that rely on matching these ops directly. Some examples (from @kuhar): https://github.com/iree-org/iree/blob/main/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVInitialVectorLowering.cpp#L148 I'm sure I could find more examples from upstream, but these are just from a quick look. Keeping this in mind, I don't agree this transformation is a canonicalization.
If you think otherwise, please send an RFC and we can discuss further there. This needs wider visibility than a PR review. The reason I ask for an RFC is that there are justifications given in this PR that I don't think hold up. For example:
A simple example of why this isn't a good argument: we also have linalg.contraction, linalg.generic and linalg.matmul in the same dialect, but we don't choose to canonicalize them to each other. We have 3+ ways of writing a matmul in Linalg. We should be consistent with this argument and apply it everywhere if so. What I would like to see in the RFC is:
|
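To make the extra matching burden concrete, here is a small Python model (all names here, such as `shape_cast_as_permutations`, are hypothetical illustrations, not APIs from this PR): a transpose op states its permutation directly, while a pattern that wants to treat a shape_cast as a transpose first has to search for a matching order-preserving permutation, and may find more than one.

```python
from itertools import permutations

def apply_perm(shape, perm):
    """Shape after transposing: output dim i takes input dim perm[i]."""
    return tuple(shape[p] for p in perm)

def is_order_preserving(shape, perm):
    """True iff the transpose moves no data, i.e. the relative order
    of the non-unit dims is unchanged (flattened layout is identical)."""
    non_unit = [p for p in perm if shape[p] != 1]
    return non_unit == sorted(non_unit)

def shape_cast_as_permutations(src, dst):
    """Every permutation a same-rank shape_cast src -> dst could stand for.
    A 'transpose-like' matcher must run this kind of search, and the
    answer need not be unique."""
    return [perm for perm in permutations(range(len(src)))
            if apply_perm(src, perm) == dst and is_order_preserving(src, perm)]

cands = shape_cast_as_permutations((4, 1, 1), (1, 4, 1))
# Two candidates survive: the original permutation is unrecoverable.
# cands == [(1, 0, 2), (2, 0, 1)]
```

Under these assumptions, a matcher that used to read one attribute now performs a search whose result is ambiguous, which is the extra complexity being described above.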
"In general", having to match a single form for the same program is what makes subsequent analysis more effective: that is, we can match a single "thing" instead of needing to match many. So this sentence as-is is a bit odd to me (but you're elaborating pretty well in the rest of your post, so don't take this as a rebuttal of your post as a whole, just a minor nitpick of this particular sentence). |
|
Hi Kunwar, thanks for your message - and apologies for the slow reply, it’s been a hectic week. I’d like to catch up F2F at the LLVM Dev Meeting next week with as many folks as possible involved in the discussion. I’ll propose some concrete next steps after that. |
Kunwar won't be there. Quinn and I will be, and I am sure we can represent this viewpoint there. |
|
I had written a reply but I agree that this needs high-bandwidth discussion. I'll also be around next week! |
|
Hi everyone, here’s a quick summary from the LLVM Dev discussion (@dcaballe, @MaheshRavishankar, @matthias-springer and I were present): We’ve reviewed this extensively, and I’ve seen no evidence that this change would:
Neither Diego nor I see any technical issues with this PR - every concern raised so far can be addressed without difficulty (**). As one specific point, let me return to Kunwar’s example, i.e. a pattern matching:

```mlir
vector.transpose 1x2x2x8 -> 2x2x8x1
vector.contract { indexing_maps = [#map] }
```

to rewrite that as:

```mlir
vector.contract { indexing_maps = [#transposed_map] }
```

Such a pattern should also match:

```mlir
vector.shape_cast 8x8 -> 2x2x8x1
vector.contract { indexing_maps = [#map] }
```

Why? Because in this case the shape_cast re-arranges the shape in exactly the same way as the transpose. Ultimately, this PR is about agreeing on the canonical form. My view remains that operations not requiring data movement should canonicalise to vector.shape_cast.
Both Diego and I support this change, but there is still disagreement on whether it qualifies as canonical (***). @ftynse 's talk on “Normal Forms for MLIR” from last week (****) offers a promising direction, but we’ll likely need to wait for Alex to implement it and propose an RFC. While that means progress here may pause for now, this discussion has helped clarify the underlying principles. Thank you for the discussion and for all the input. If I missed or misinterpreted something, please post a clarification. HTH -Andrzej (*) Mahesh offered to share an example, so I may revisit this point. |
|
Thank you @banach-space for summarizing the discussion. I think I vehemently agree with this
Specifically
is a nice formulation of the problem. |
I talked to @ftynse after the talk, and it seems that canonicalization is orthogonal to this concept of normal form, and I really don't see how canonicalization as a concept would go away.
This is not a rebuttal of canonicalization: we can try to apply principles to pick one or another form, but if at the end of everything two forms are completely equivalent, picking one is just a convention (e.g. "constants are moved to the right of commutative operations"). |
I think we all agree that choosing a canonical form could be a matter of preference but that's not the issue we're discussing here. Blocking this proposal means we will NOT have a canonical form at all. This leaves developers having to handle all potential variants AFTER running the canonicalizer which is not a good state to be for anyone. Adding to what Andrzej mentioned, I'd like to share some key ideas and points that were made during our F2F discussion:
To reiterate our primary goal here: we're not trying to break any downstream projects but to improve the overall situation for all the Vector dialect users. The shape cast problem has been a recurring pain point, as previous RFCs and discussions demonstrate. There have been multiple "energy functions" trying to improve the overall situation for everyone, but they have been shut down without offering a sound alternative. Even at this point, I'm still failing to see a strong technical justification for the block. |
|
I was hoping that we had reached some sort of understanding, but the earlier response seems to get back to the same loop of discussion. I'd like to break out of it. We are in agreement about almost all of it, except for the aspect of canonicalization, and I will try once again to clarify what has been said earlier.
I know, this is on me, I haven't collected it yet, but I will try to dig this up.
I disagree that it is. Just because they could come in some input program does not make it a high-level abstraction. Higher-level abstractions are necessarily operations that carry "more information in a succinct manner". You could have Linalg ops and affine ops exist in the same input program. That does not put them at the same level of abstraction.
Again, I disagree. This is not about IREE, and not relevant to the conversation. The IREE project has enough people with the background needed to handle these things. If anything, folks working on IREE are bringing their perspective of building a full-scale compiler using these constructs to bear here, because we see more directly the impact of choices made upstream. The IREE project has ways to deal with any changes that don't align with IREE. This has nothing to do with tile + fuse, since this is at the vector level. To restate, overall we are on the same page in terms of shape_casts being the right abstraction to handle unit-dim folding before lowering to LLVM.
This is the crux of the problem. Canonicalizations in MLIR are applied repeatedly and without control. That is how they are designed. There have been several talks by Matthias and Alex about how canonicalizations can silently change your compiler's behavior just on an LLVM bump. Canonical form implies it is always better to choose this form over others, and when you can't prove that, you can't define the canonical form. "Just pick one" is not a robust enough solution long term (someone else can come along and pick something else, and there is no forcing function that disallows that). The fact that canonicalizations run everywhere without control, plus canonicalizations in MLIR having become a kitchen sink of patterns, is really the problem.
I don't think it does. The whole point of Alex's talk (and this was the motivation for normal forms) is that in a lot of cases it is impossible to decide canonical forms.
Again, let's leave IREE out of it. How IREE handles things is completely irrelevant to this discussion.
I understand you are not trying to break downstream projects, and you see this as a way to improve the overall situation. There is a disagreement that making this "canonical" does that. The only disagreement is the requirement that we have this as a canonicalization. Lots of folks here disagree that it is. We don't need to block progress on things just for this one point about picking a canonical form. A lot of things added here are definitely useful for many downstream projects. Particularly, as a sequence of transformations to convert vector to the LLVM dialect, having patterns that convert broadcasts, transposes, etc. to shape casts and fold them away is extremely valuable. We are just blocking here because we want to make some representation canonical. Instead we could provide a sequence of transformations that will take a vector dialect program and convert it to LLVM, which is readily usable everywhere. I would suggest we go down that path instead of just going in circles trying to decide on a canonical form. |
|
@dcaballe and @banach-space see this post here: https://discourse.llvm.org/t/rfc-update-to-general-design-section-of-operation-canonicalizations-in-mlir/79355?u=maheshravishankar . It talks about how vector.transpose captures more information than a vector.shape_cast, and how you cannot always go from shape_cast to transpose. This is exactly the issue with treating vector.shape_cast as the "canonical" representation for transposes and hoping that we can always lift back to the original representation. |
|
Thanks @MaheshRavishankar , as promised I am returning to this after you've shared your example.
I've extracted this repro as something representative (*):

```mlir
func.func @transpose_to_shape_cast_1(%0 : vector<4x1x1xf32>) -> vector<1x4x1xf32> {
  %res = vector.transpose %0, [2, 0, 1] : vector<4x1x1xf32> to vector<1x4x1xf32>
  return %res : vector<1x4x1xf32>
}

// -----

func.func @transpose_to_shape_cast_2(%0 : vector<4x1x1xf32>) -> vector<1x4x1xf32> {
  %res = vector.transpose %0, [1, 0, 2] : vector<4x1x1xf32> to vector<1x4x1xf32>
  return %res : vector<1x4x1xf32>
}
```

QUESTION/COMMENT: Aren't the examples above identical operations?

YES - LLVM example!

```shell
# Canonicalize to vector.shape_cast, then lower.
$ mlir-opt repro.mlir -canonicalize -test-lower-to-llvm --split-input-file
# Lower as vector.transpose.
$ mlir-opt repro.mlir -test-lower-to-llvm --split-input-file
```

In both cases I get the following (testing using this PR):

```mlir
module {
  llvm.func @transpose_to_shape_cast_1(%arg0: !llvm.array<4 x array<1 x vector<1xf32>>>) -> !llvm.array<1 x array<4 x vector<1xf32>>> {
    %0 = llvm.mlir.poison : !llvm.array<1 x array<4 x vector<1xf32>>>
    %1 = llvm.extractvalue %arg0[0, 0] : !llvm.array<4 x array<1 x vector<1xf32>>>
    %2 = llvm.insertvalue %1, %0[0, 0] : !llvm.array<1 x array<4 x vector<1xf32>>>
    %3 = llvm.extractvalue %arg0[1, 0] : !llvm.array<4 x array<1 x vector<1xf32>>>
    %4 = llvm.insertvalue %3, %2[0, 1] : !llvm.array<1 x array<4 x vector<1xf32>>>
    %5 = llvm.extractvalue %arg0[2, 0] : !llvm.array<4 x array<1 x vector<1xf32>>>
    %6 = llvm.insertvalue %5, %4[0, 2] : !llvm.array<1 x array<4 x vector<1xf32>>>
    %7 = llvm.extractvalue %arg0[3, 0] : !llvm.array<4 x array<1 x vector<1xf32>>>
    %res = llvm.insertvalue %7, %6[0, 3] : !llvm.array<1 x array<4 x vector<1xf32>>>
    llvm.return %res : !llvm.array<1 x array<4 x vector<1xf32>>>
  }
}

// -----

module {
  llvm.func @transpose_to_shape_cast_2(%arg0: !llvm.array<4 x array<1 x vector<1xf32>>>) -> !llvm.array<1 x array<4 x vector<1xf32>>> {
    %0 = llvm.mlir.poison : !llvm.array<1 x array<4 x vector<1xf32>>>
    %1 = llvm.extractvalue %arg0[0, 0] : !llvm.array<4 x array<1 x vector<1xf32>>>
    %2 = llvm.insertvalue %1, %0[0, 0] : !llvm.array<1 x array<4 x vector<1xf32>>>
    %3 = llvm.extractvalue %arg0[1, 0] : !llvm.array<4 x array<1 x vector<1xf32>>>
    %4 = llvm.insertvalue %3, %2[0, 1] : !llvm.array<1 x array<4 x vector<1xf32>>>
    %5 = llvm.extractvalue %arg0[2, 0] : !llvm.array<4 x array<1 x vector<1xf32>>>
    %6 = llvm.insertvalue %5, %4[0, 2] : !llvm.array<1 x array<4 x vector<1xf32>>>
    %7 = llvm.extractvalue %arg0[3, 0] : !llvm.array<4 x array<1 x vector<1xf32>>>
    %res = llvm.insertvalue %7, %6[0, 3] : !llvm.array<1 x array<4 x vector<1xf32>>>
    llvm.return %res : !llvm.array<1 x array<4 x vector<1xf32>>>
  }
}
```

Note, the two outputs are identical.

YES - SPIR-V example!

```shell
# Canonicalize to vector.shape_cast, then lower.
$ mlir-opt repro.mlir -canonicalize -test-convert-to-spirv --split-input-file
# Lower as vector.transpose.
$ mlir-opt repro.mlir -test-convert-to-spirv --split-input-file
```

In both cases I get the following (testing using this PR):

```mlir
module {
  func.func @transpose_to_shape_cast_1(%arg0: vector<1xf32>, %arg1: vector<1xf32>, %arg2: vector<1xf32>, %arg3: vector<1xf32>) -> (vector<1xf32>, vector<1xf32>, vector<1xf32>, vector<1xf32>) {
    return %arg0, %arg1, %arg2, %arg3 : vector<1xf32>, vector<1xf32>, vector<1xf32>, vector<1xf32>
  }
}

// -----

module {
  func.func @transpose_to_shape_cast_2(%arg0: vector<1xf32>, %arg1: vector<1xf32>, %arg2: vector<1xf32>, %arg3: vector<1xf32>) -> (vector<1xf32>, vector<1xf32>, vector<1xf32>, vector<1xf32>) {
    return %arg0, %arg1, %arg2, %arg3 : vector<1xf32>, vector<1xf32>, vector<1xf32>, vector<1xf32>
  }
}
```

SPIR-V makes it even clearer that we are dealing with a NO-OP 😅

FINAL THOUGHTS

I argue that in all cases we are dealing with one operation for which we have multiple names (`vector.transpose` and `vector.shape_cast`). I obviously might be missing something - please correct me if that's the case. I am sharing this to make my mental model clear and to avoid confusion.

-Andrzej

(*) Please provide other examples if this does not capture what you had in mind. |
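To sanity-check the "identical operations" claim without running mlir-opt, here is a pure-Python sketch of a row-major transpose (my own model, not the MLIR lowering): both repro permutations leave the flattened element order of a 4x1x1 vector untouched, while a genuine transpose does not.

```python
from itertools import product

def transpose(data, shape, perm):
    """Row-major transpose of flat `data`: output dim i takes input dim perm[i]."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    new_shape = tuple(shape[p] for p in perm)
    out = []
    for out_idx in product(*(range(d) for d in new_shape)):
        in_idx = [0] * len(shape)
        for out_dim, i in enumerate(out_idx):
            in_idx[perm[out_dim]] = i
        out.append(data[sum(i * s for i, s in zip(in_idx, strides))])
    return out, new_shape

data = list(range(4))                      # a 4x1x1 vector, row-major
for perm in ([2, 0, 1], [1, 0, 2]):        # the two repro permutations
    flat, shp = transpose(data, (4, 1, 1), perm)
    assert shp == (1, 4, 1) and flat == data   # no data movement: a shape_cast

moved, _ = transpose(list(range(6)), (2, 3), [1, 0])
assert moved == [0, 3, 1, 4, 2, 5]         # a real transpose reorders elements
```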
|
@banach-space I think I might not have communicated the intent of my example properly. It was more to show why vector.transpose carries more information than vector.shape_cast. It of course lowers to the same thing if you are just lowering to LLVM or SPIR-V. I have already agreed that while lowering to LLVM you should lower both these transposes to shape_cast. |
|
I think there is a misunderstanding about what we can expect from a canonical form. A canonical form should indeed allow us to convert to any equivalent representation. However, while these representations are semantically equivalent, the canonical form doesn't (and shouldn't need to) preserve information about which specific representation was part of the input IR. To illustrate this with a simple but realistic example, consider different ways to represent "multiplication by 2" (I haven't checked but this is probably something that LLVM canonicalizes today to a single form):
If we choose option 2 as our canonical form, we can certainly convert to both options 1 and 3 from it when needed. What we can't do (and what isn't a requirement for canonical forms) is to automatically know which of these three forms the input IR had without any additional context. Bringing this back to the shape_cast discussion, the key point here is: if preserving exactly the original input representation is important for your use case, then canonicalization is not the right transformation to apply at that stage of your pipeline. That is not the right expectation to have for a canonical form. |
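For what it's worth, the "multiplication by 2" example can be checked directly; the three forms below are my own rendering of the classic strength-reduction trio (the original list did not survive this page's formatting):

```python
def mul_add(x):   return x + x    # form 1: addition
def mul_mul(x):   return x * 2    # form 2: multiplication
def mul_shift(x): return x << 1   # form 3: left shift

# Semantically identical on integers...
for x in range(-8, 9):
    assert mul_add(x) == mul_mul(x) == mul_shift(x)

# ...so once everything is rewritten to one canonical form, nothing in the
# result records which of the three forms the input originally used.
```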
|
I found this "visualization" from Cursor quite illustrative:

Data Layout Diagram
Original Vector:
|
Answering my own question, I can think of one use case: any kind of traversal that needs to track or propagate a property across one of the unit dimensions in the example wouldn't be able to do so with the shape_cast form. Conclusion 1: We can’t canonicalize a transpose operation to a shape cast when multiple unit dimensions are transposed. The data layout or dimension mapping across the operation becomes ambiguous with the shape_cast. Great, that’s progress! We have identified something technical and specific. I suggest that we continue focusing on the technical aspects of the different IR forms. Could we continue this exercise? Could we come up with similar examples for:
|
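Conclusion 1 can be demonstrated mechanically. In the sketch below (my own model), a per-dimension property, think scalability in `vector<[4]x1x1xf32>`, is propagated through each of the two repro permutations: the data movement is identical (none), but the resulting dimension mapping differs, and that mapping is exactly what a shape_cast does not record.

```python
def propagate(dim_props, perm):
    """Carry a per-dimension property through a transpose:
    output dim i inherits the property of input dim perm[i]."""
    return tuple(dim_props[p] for p in perm)

props = ("scalable", "unit_a", "unit_b")   # properties of the 4x1x1 dims

a = propagate(props, (2, 0, 1))            # transpose_to_shape_cast_1
b = propagate(props, (1, 0, 2))            # transpose_to_shape_cast_2

assert a == ("unit_b", "scalable", "unit_a")
assert b == ("unit_a", "scalable", "unit_b")
assert a != b   # same (absent) data movement, different dimension mapping
```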
I am happy Cursor was able to give you a better explanation of what I was trying to say all this while. Good to have reached this common state. As we discussed previously, stating that "certain" transposes/broadcasts are canonically shape_casts, and forcing them to become shape_casts without control, is creating unnecessary complication in the definition of canonicalization. If some transformation relies on following dimensions through broadcasts/transposes, it now has to look at a shape_cast, decide if it is "convertible to a transpose/broadcast", and then handle that appropriately. This does not seem like a great setup. |
|
Thank you for the detailed discussion, @dcaballe - that was very helpful in clarifying the underlying issues.
Agreed. That’s one concrete example (*). @MaheshRavishankar, could you help us identify other specific cases so that we can better scope or constrain this change? All in all, given the nuances discussed, I don’t see a specific blocker preventing this from being merged - or is there?
I think we may have to agree to disagree here. -Andrzej (*) A transpose operation to a shape cast when multiple unit dimensions are transposed. |
We might be having different reads of the blocker. To me this discussion is uncovering more reasons why this change shouldn't be merged (this kind of thing is what I was saying would be an issue from the get-go).
I want to re-iterate: this is not about my use case. We can find ways to work around things either way. So I am disagreeing more with the approach here, rather than saying "this doesn't fit my use case". |
Discussions suggest that we should use shape_cast as the canonical form of broadcast/transpose/extract where possible (see #138777).
For example these can all be expressed as shape casts:
This PR adds canonicalization patterns to convert the above 3 examples to shape_casts.
I've added some more comments as review comments.
I'm happy to split this PR up and add the new patterns separately.
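The idea that broadcasts, transposes and extracts touching only unit dimensions are all shape_casts can be illustrated with a small Python model (my own nested-list stand-in for vectors; the PR's actual MLIR examples are not reproduced here):

```python
def flat(x):
    """Flatten an arbitrarily nested list (row-major order)."""
    return sum((flat(e) for e in x), []) if isinstance(x, list) else [x]

src = [[[0.0, 1.0], [2.0, 3.0]]]               # stands in for vector<1x2x2xf32>

extract   = src[0]                              # extract [0]         : 1x2x2 -> 2x2
broadcast = [src]                               # broadcast           : 1x2x2 -> 1x1x2x2
transposed = [[[a], [b]] for a, b in src[0]]    # transpose [1, 2, 0] : 1x2x2 -> 2x2x1

for r in (extract, broadcast, transposed):
    assert flat(r) == flat(src)                 # element order unchanged: a shape_cast
```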