Contributor

@newling newling commented May 19, 2025

Discussions suggest that we should use shape_cast as a canonical form of broadcast/transpose/extract where possible (see #138777)

For example these can all be expressed as shape casts:

%0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
%1 = vector.transpose %arg1, [1, 0] : vector<2x1xi8> to vector<1x2xi8>
%2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8>

This PR adds canonicalizations that convert the three examples above to shape_casts.

I've added some more comments as review comments.

I'm happy to split this PR up and add the new patterns separately.
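
As a rough illustration of the condition the new patterns check, here is a standalone C++ sketch with no MLIR dependency. The names `Shape`, `numElements`, `broadcastIsShapeCast` and `extractIsShapeCast` are hypothetical, and scalable dimensions are ignored. The idea, taken from the patch, is that an already-verified broadcast, or an extract at an all-zero position, can be rewritten as a shape_cast when it preserves the element count.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical stand-in for a fixed-size vector shape, e.g. {1, 1, 4}.
using Shape = std::vector<int64_t>;

// Product of all dimension sizes; a rank-0 shape holds exactly one element.
int64_t numElements(const Shape &shape) {
  return std::accumulate(shape.begin(), shape.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

// Assumes the broadcast is already legal. If no element is duplicated, the
// element count is unchanged and the op only reshapes its operand.
bool broadcastIsShapeCast(const Shape &source, const Shape &result) {
  return numElements(source) == numElements(result);
}

// Mirrors the guard in the patch: every position index must be the constant
// 0, and the extracted slice must hold every element of the source.
bool extractIsShapeCast(const Shape &source,
                        const std::vector<int64_t> &position,
                        const Shape &result) {
  assert(position.size() <= source.size() && "position exceeds source rank");
  for (int64_t index : position)
    if (index != 0)
      return false;
  return numElements(source) == numElements(result);
}
```

With these definitions, `broadcastIsShapeCast({4}, {1, 1, 4})` and `extractIsShapeCast({1, 4}, {0}, {4})` both hold, while an extract of `vector<1x4>` from `vector<8x1x4>` does not qualify because it drops elements.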

@newling newling changed the title from "[vector][mlir] Canonicalize to shape_cast where possible" to "[wip][vector][mlir] Canonicalize to shape_cast where possible" on May 19, 2025
@newling newling changed the title from "[wip][vector][mlir] Canonicalize to shape_cast where possible" to "[vector][mlir] Canonicalize to shape_cast where possible" on May 19, 2025
@newling newling force-pushed the canonicalize_to_shape_cast branch from d546ab3 to 29d41d8 on June 5, 2025 18:07
@newling newling force-pushed the canonicalize_to_shape_cast branch from 29d41d8 to f2e5417 on June 25, 2025 23:03

github-actions bot commented Jun 25, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Contributor Author

@banach-space I'm getting back to this PR. Peephole question: is this operation ok? i.e. is

vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>

an acceptable operation to have after running mlir-opt -arm-sme-vector-legalization -cse -canonicalize ?

Contributor

In general, yes. But I can't guarantee there's no logic that expects vector<[4]x1xf32> instead of vector<1x[4]xf32> ;-) If that's the case, we will fix it and I will be grateful for uncovering this :)

@newling newling force-pushed the canonicalize_to_shape_cast branch from 7bc5da0 to e673522 on June 26, 2025 15:31
Contributor Author

Author note: I've removed this, as now it happens in 2 steps during canonicalization. The first converts the Broadcast to a ShapeCast. The second combines the 2 ShapeCasts.

Contributor Author

Author note: I've removed this, as it now happens in 2 steps during canonicalization. The first (new) step is to rewrite the transpose as a shape_cast. The second step is to fold shape_cast(shape_cast) to shape_cast.

Contributor Author

Author note: I've removed this pattern, as it is a special case of TransposeToShapeCast
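
To make "order preserving" concrete, here is a minimal standalone sketch for fixed-size shapes. The helper name `isOrderPreservingTranspose` is hypothetical and this is not the upstream `isOrderPreserving` implementation; it assumes that `permutation[i]` names the source dimension that becomes result dimension `i`, and it ignores scalable dimensions. A transpose only relabels data if the non-unit dimensions keep their relative order.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: returns true if the transpose moves only unit
// dimensions, so the flattened element order is unchanged.
bool isOrderPreservingTranspose(const std::vector<int64_t> &shape,
                                const std::vector<int64_t> &permutation) {
  assert(shape.size() == permutation.size() && "rank mismatch");
  int64_t lastNonUnitDim = -1;
  for (int64_t dim : permutation) {
    assert(dim >= 0 && static_cast<std::size_t>(dim) < shape.size());
    if (shape[dim] == 1)
      continue; // Moving a unit dim never reorders elements.
    if (dim < lastNonUnitDim)
      return false; // Two non-unit dims were swapped.
    lastNonUnitDim = dim;
  }
  return true;
}
```

Under these assumptions, the 2-D case with one unit dimension (for example shape `{4, 1}` with permutation `{1, 0}`) is one instance of the general check, which is consistent with the removed pattern being a special case.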

Contributor Author

Author note: removed these tests, as the pattern they are testing is removed

Contributor

Shouldn't we keep them? Shouldn't they still be canonicalized?

Contributor Author

I'll add them back. Yes, they're still canonicalized.

Contributor Author

Author note: as the vector.transpose is canonicalized to a vector.shape_cast, the lowering test is now moved to shape_cast lowering

@llvmbot
Member

llvmbot commented Jun 26, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir-sme

Author: James Newling (newling)

Changes

Discussions suggest that we should use shape_cast as a canonical form of broadcast/transpose/extract where possible (see #138777)

For example these can all be expressed as shape casts:

%0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
%1 = vector.transpose %arg1, [1, 0] : vector<2x1xi8> to vector<1x2xi8>
%2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8>

This PR adds canonicalizations that convert the three examples above to shape_casts.

I've added some more comments as review comments.

I'm happy to split this PR up and add the new patterns separately.


Patch is 41.81 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140583.diff

10 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+84-53)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp (-61)
  • (modified) mlir/test/Dialect/ArmSME/vector-legalization.mlir (+4-4)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+26-41)
  • (modified) mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir (+2-2)
  • (added) mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir (+162)
  • (modified) mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir (+60)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir (+5-7)
  • (modified) mlir/test/Dialect/Vector/vector-transpose-lowering.mlir (-85)
  • (modified) mlir/test/Dialect/Vector/vector-warp-distribute.mlir (+4-4)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 862ed7bae1fbb..08cc4af158e10 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2351,11 +2351,41 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
   return success();
 }
 
+/// BEFORE:
+/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
+/// AFTER:
+/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType sourceType = extractOp.getSourceVectorType();
+    VectorType outType = dyn_cast<VectorType>(extractOp.getType());
+    if (!outType)
+      return failure();
+
+    // Negative values in `position` indicates poison, which cannot be
+    // represented with a shape_cast
+    if (llvm::any_of(extractOp.getMixedPosition(),
+                     [](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
+      return failure();
+
+    if (sourceType.getNumElements() != outType.getNumElements())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, outType,
+                                                     extractOp.getVector());
+    return success();
+  }
+};
+
 } // namespace
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+  results
+      .add<ExtractOpFromBroadcast, ExtractOpFromCreateMask, ExtractToShapeCast>(
+          context);
   results.add(foldExtractFromShapeCastToShapeCast);
   results.add(foldExtractFromFromElements);
 }
@@ -2867,13 +2897,36 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
     return success();
   }
 };
+
+/// BEFORE:
+/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+/// AFTER:
+/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+struct BroadcastToShapeCast final
+    : public OpRewritePattern<vector::BroadcastOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::BroadcastOp broadcast,
+                                PatternRewriter &rewriter) const override {
+    auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType());
+    if (!sourceType) {
+      return rewriter.notifyMatchFailure(
+          broadcast, "source is a scalar, shape_cast doesn't support scalar");
+    }
+
+    VectorType outType = broadcast.getType();
+    if (sourceType.getNumElements() != outType.getNumElements())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcast, outType,
+                                                     broadcast.getSource());
+    return success();
+  }
+};
 } // namespace
 
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
-  // calling `populateCastAwayVectorLeadingOneDimPatterns`
-  results.add<BroadcastFolder>(context);
+  results.add<BroadcastFolder, BroadcastToShapeCast>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -5991,10 +6044,7 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
   }
 };
 
-/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
-///   i) Y = ShapeCast(X), or
-///  ii) Y = Broadcast(X)
-/// If both (i) and (ii) are possible, (i) is chosen.
+/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as Y = Broadcast(X)
 class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -6009,22 +6059,6 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
     auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
     bool srcIsScalar = !srcVectorType;
 
-    // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
-    // Example:
-    // %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
-    // %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
-    // to
-    // %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
-    if (srcVectorType) {
-      if (srcVectorType.getNumElements() ==
-          shapeCastOp.getResultVectorType().getNumElements()) {
-        rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
-            shapeCastOp, shapeCastOp.getResultVectorType(),
-            broadcastOp.getSource());
-        return success();
-      }
-    }
-
     // Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
     // Example
     // %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
@@ -6233,7 +6267,7 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
   // %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
   // %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
   //
-  // Example of what NOT to fold:
+  // Example of what not to fold:
   // %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
   //
   if (getSourceVectorType() == getResultVectorType() &&
@@ -6359,32 +6393,6 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
   }
 };
 
-/// Folds transpose(shape_cast) into a new shape_cast.
-class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(TransposeOp transposeOp,
-                                PatternRewriter &rewriter) const override {
-    auto shapeCastOp =
-        transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
-    if (!shapeCastOp)
-      return failure();
-    if (!isOrderPreserving(transposeOp))
-      return failure();
-
-    VectorType resultType = transposeOp.getType();
-
-    // We don't need to check isValidShapeCast at this point, because it is
-    // guaranteed that merging the transpose into the the shape_cast is a valid
-    // shape_cast, because the transpose just inserts/removes ones.
-
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
-                                                     shapeCastOp.getSource());
-    return success();
-  }
-};
-
 /// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
 /// 'order preserving', where 'order preserving' means the flattened
 /// inputs and outputs of the transpose have identical (numerical) values.
@@ -6480,12 +6488,35 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
   }
 };
 
+/// BEFORE:
+/// %0 = vector.transpose %arg0, [0, 2, 1] :
+///                   vector<2x1x2xf32> to vector<2x2x1xf32>
+/// AFTER:
+/// %0 = vector.shape_cast %arg0 :
+///                   vector<2x1x2xf32> to vector<2x2x1xf32>
+struct TransposeToShapeCast final
+    : public OpRewritePattern<vector::TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::TransposeOp transpose,
+                                PatternRewriter &rewriter) const override {
+
+    if (!isOrderPreserving(transpose)) {
+      return rewriter.notifyMatchFailure(
+          transpose, "not order preserving, so not semantically a 'copy'");
+    }
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+        transpose, transpose.getType(), transpose.getVector());
+    return success();
+  }
+};
+
 } // namespace
 
 void vector::TransposeOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
-  results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
-              FoldTransposeSplat, FoldTransposeBroadcast>(context);
+  results.add<FoldTransposeBroadcast, FoldTransposeCreateMask,
+              FoldTransposeSplat, TransposeFolder, TransposeToShapeCast>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 732e316c93381..71410eda28297 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -11,7 +11,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -382,64 +381,6 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
   vector::VectorTransposeLowering vectorTransposeLowering;
 };
 
-/// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied
-/// to 2D vectors with at least one unit dim. For example:
-///
-/// Replace:
-///   vector.transpose %0, [1, 0] : vector<4x1xi32>> to
-///                                 vector<1x4xi32>
-/// with:
-///   vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32>
-///
-/// Source with leading unit dim (inverse) is also replaced. Unit dim must
-/// be fixed. Non-unit dim can be scalable.
-///
-/// TODO: This pattern was introduced specifically to help lower scalable
-/// vectors. In hindsight, a more specialised canonicalization (for shape_cast's
-/// to cancel out) would be preferable:
-///
-///  BEFORE:
-///     %0 = some_op
-///     %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32>
-///     %2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
-///  AFTER:
-///     %0 = some_op
-///     %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32>
-///
-/// Given the context above, we may want to consider (re-)moving this pattern
-/// at some later time. I am leaving it for now in case there are other users
-/// that I am not aware of.
-class Transpose2DWithUnitDimToShapeCast
-    : public OpRewritePattern<vector::TransposeOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  Transpose2DWithUnitDimToShapeCast(MLIRContext *context,
-                                    PatternBenefit benefit = 1)
-      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
-
-  LogicalResult matchAndRewrite(vector::TransposeOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.getVector();
-    VectorType resType = op.getResultVectorType();
-
-    // Set up convenience transposition table.
-    ArrayRef<int64_t> transp = op.getPermutation();
-
-    if (resType.getRank() == 2 &&
-        ((resType.getShape().front() == 1 &&
-          !resType.getScalableDims().front()) ||
-         (resType.getShape().back() == 1 &&
-          !resType.getScalableDims().back())) &&
-        transp == ArrayRef<int64_t>({1, 0})) {
-      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
-      return success();
-    }
-
-    return failure();
-  }
-};
-
 /// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
 /// If the strategy is Shuffle1D, it will be lowered to:
 ///   vector.shape_cast 2D -> 1D
@@ -511,8 +452,6 @@ class TransposeOp2DToShuffleLowering
 void mlir::vector::populateVectorTransposeLoweringPatterns(
     RewritePatternSet &patterns,
     VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit) {
-  patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(),
-                                                  benefit);
   patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
       vectorTransposeLowering, patterns.getContext(), benefit);
 }
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index 6cdf576272ebc..a9a2fdccdd82f 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -480,11 +480,11 @@ func.func @lift_illegal_transpose_to_memory_with_in_bounds_attr(%a: index, %b: i
 
 // -----
 
-// The pass should do nothing (and not crash).
-// CHECK-LABEL: @illegal_transpose_no_defining_source_op
-func.func @illegal_transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32>
+// CHECK-LABEL: @transpose_no_defining_source_op
+func.func @transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32>
 {
-  // CHECK: vector.transpose
+  // CHECK:      vector.shape_cast
+  // CHECK-SAME: vector<[4]x1xf32> to vector<1x[4]xf32>
   %0 = vector.transpose %vec, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
   return %0 : vector<1x[4]xf32>
 }
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 65b73375831da..374c71c814e89 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -451,16 +451,25 @@ func.func @extract_strided_fold_insert(%a: vector<2x8xf32>, %b: vector<1x4xf32>,
 // -----
 
 // CHECK-LABEL: transpose_3D_identity
-// CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>)
+//  CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>)
+//  CHECK-NEXT: return [[ARG]]
 func.func @transpose_3D_identity(%arg : vector<4x3x2xf32>) -> vector<4x3x2xf32> {
-  // CHECK-NOT: transpose
   %0 = vector.transpose %arg, [0, 1, 2] : vector<4x3x2xf32> to vector<4x3x2xf32>
-  // CHECK-NEXT: return [[ARG]]
   return %0 : vector<4x3x2xf32>
 }
 
 // -----
 
+// CHECK-LABEL: transpose_0D_identity
+//  CHECK-SAME: ([[ARG:%.*]]: vector<i8>)
+//  CHECK-NEXT: return [[ARG]]
+func.func @transpose_0D_identity(%arg : vector<i8>) -> vector<i8> {
+  %0 = vector.transpose %arg, [] : vector<i8> to vector<i8>
+  return %0 : vector<i8>
+}
+
+// -----
+
 // CHECK-LABEL: transpose_2D_sequence
 // CHECK-SAME: ([[ARG:%.*]]: vector<4x3xf32>)
 func.func @transpose_2D_sequence(%arg : vector<4x3xf32>) -> vector<4x3xf32> {
@@ -753,12 +762,13 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
 
 // -----
 
+
 // CHECK-LABEL: negative_fold_extract_broadcast
-//       CHECK:   vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
-//       CHECK:   vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32>
+//       CHECK:   vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x2x4xf32>
+//       CHECK:   vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x2x4xf32>
 func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32> {
-  %b = vector.broadcast %a : vector<1x1xf32> to vector<1x1x4xf32>
-  %r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x1x4xf32>
+  %b = vector.broadcast %a : vector<1x1xf32> to vector<1x2x4xf32>
+  %r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x2x4xf32>
   return %r : vector<4xf32>
 }
 
@@ -797,8 +807,8 @@ func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
 // rank(extract_output) < rank(broadcast_input)
 func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
   %idx0 : index, %idx1 : index) -> vector<4xf32> {
-  %b = vector.broadcast %a : vector<2x4xf32> to vector<1x2x4xf32>
-  %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
+  %b = vector.broadcast %a : vector<2x4xf32> to vector<2x2x4xf32>
+  %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<2x2x4xf32>
   return %r : vector<4xf32>
 }
 
@@ -1033,30 +1043,6 @@ func.func @canonicalize_broadcast_shapecast_to_broadcast_scalar(%arg0: f32) -> v
 
 // -----
 
-// In this test, broadcast (2)->(1,2,1) is not legal, but shape_cast (2)->(1,2,1) is.
-// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapcast
-//   CHECK-NOT:   vector.broadcast
-//       CHECK:   vector.shape_cast {{.+}} : vector<2xf32> to vector<1x2x1xf32>
-func.func @canonicalize_broadcast_shapecast_to_shapcast(%arg0 : vector<2xf32>) -> vector<1x2x1xf32> {
-  %0 = vector.broadcast %arg0 : vector<2xf32> to vector<1x2xf32>
-  %1 = vector.shape_cast %0 : vector<1x2xf32> to vector<1x2x1xf32>
-  return %1 : vector<1x2x1xf32>
-}
-
-// -----
-
-// In this test, broadcast (1)->(1,1) and shape_cast (1)->(1,1) are both legal. shape_cast is chosen.
-// CHECK-LABEL: func @canonicalize_broadcast_shapecast_both_possible
-//   CHECK-NOT:   vector.broadcast
-//       CHECK:   vector.shape_cast {{.+}} : vector<1xf32> to vector<1x1xf32>
-func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>) -> vector<1x1xf32> {
-    %0 = vector.broadcast %arg0 : vector<1xf32> to vector<1x1x1xf32>
-    %1 = vector.shape_cast %0 : vector<1x1x1xf32> to vector<1x1xf32>
-    return %1 : vector<1x1xf32>
-}
-
-// -----
-
 // CHECK-LABEL: fold_vector_transfer_masks
 func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
   // CHECK: %[[C0:.+]] = arith.constant 0 : index
@@ -1920,12 +1906,12 @@ func.func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> {
 
 // -----
 
-// CHECK-LABEL: func @insert_extract_to_broadcast
+// CHECK-LABEL: func @insert_extract_to_shape_cast
 //  CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>)
-//       CHECK:   %[[V0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
-//       CHECK:   %[[V1:.*]] = vector.broadcast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
+//       CHECK:   %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32>
+//       CHECK:   %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
 //       CHECK:   return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32>
-func.func @insert_extract_to_broadcast(%arg0 : vector<1x1x4xf32>,
+func.func @insert_extract_to_shape_cast(%arg0 : vector<1x1x4xf32>,
   %arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) {
   %0 = vector.extract %arg0[0, 0] : vector<4xf32> from vector<1x1x4xf32>
   %1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32>
@@ -2277,7 +2263,7 @@ func.func @shuffle_1d_rhs_poison() -> vector<4xi32> {
 
 // CHECK-LABEL: func @shuffle_canonicalize_0d
 func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vector<1xi32> {
-  // CHECK: vector.broadcast %{{.*}} : vector<i32> to vector<1xi32>
+  // CHECK: vector.shape_cast %{{.*}} : vector<i32> to vector<1xi32>
   %shuffle = vector.shuffle %v0, %v1 [0] : vector<i32>, vector<i32>
   return %shuffle : vector<1xi32>
 }
@@ -2764,9 +2750,8 @@ func.func @transfer_read_from_rank_reducing_extract_slice(%src: tensor<1x8x8x8xf
 // CHECK-LABEL: func.func @extract_from_broadcast
 func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> {
   %0 = vector.broadcast %src : vector<1x1x1xf32> to vector<1x1x32x1xf32>
-
-  //  CHECK-NEXT:   %0 = vector.extract {{.*}}[0, 0] : vector<1xf32> from vector<1x1x1xf32>
-  //  CHECK-NEXT:   return %0 : vector<1xf32>
+  //  CHECK-NEXT:   %[[RES:.*]] = vector.shape_cast{{.*}} vector<1x1x1xf32> to vector<1xf32>
+  //  CHECK-NEXT:   return %[[RES]] : vector<1xf32>
   %1 = vector.extract %0[0, 0, 31] : vector<1xf32> from vector<1x1x32x1xf32>
   return %1: vector<1xf32>
 }
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
index fdab2a8918a2e..d5f96a8928770 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
@@ -81,8 +81,8 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<
 
 // CHECK-LABEL: func @to_shape_cast_rank2_to_rank1(
 //  CHECK-SAME:       %[[A:.*]]: vector<1x2xi8>)
-//       CHECK:       %[[EXTRACT:.*]] = vector.extract %[[A]][0] : vector<2xi8> from vector<1x2xi8>
-//       CHECK:       return %[[EXTRACT]] : vector<2xi8>
+//       CHECK:       %[[SC:.*]] = vector.shape_cast %[[A]] : vector<1x2xi8> to vector<2xi8>
+//       CHECK:       return %[[SC]] : vector<2xi8>
 func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> {
   %0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
   %1 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8>
diff --git...
[truncated]

@newling
Contributor Author

newling commented Jun 26, 2025

Hi @banach-space and @dcaballe, I've pulled this PR out of draft mode, so please feel free to comment on it whenever!

Contributor

@dcaballe dcaballe left a comment

Nice! LGTM in general. The only general comment is to make sure we don't reduce testing coverage. I think we should keep/update the tests even for those cases where the pattern is removed.

Contributor

Keep both tests, one with the original shape and one with the new ones?

Unrelated: it looks like we are missing a canonicalization pattern here. Shouldn't this be turned into a single vector.broadcast to vector<4xf32>?

Contributor Author

@newling newling Jun 26, 2025

> Keep both tests, one with the original shape and one with the new ones?

Makes sense, will do.

> Unrelated: it looks like we are missing a canonicalization pattern here? This should be turned into a single vector.broadcast to vector<4xf32>?

No, because you can't broadcast <1x1xf32> to <4xf32> -- broadcasts can never reduce rank in Vector. FWIW this is slightly related to my comment here: this would be simpler if ops didn't do implicit shape casting. Suppose the IR were instead something like

 %s = vector.shape_cast %a : vector<1x1xf32> to vector<1x1x1xf32>
 %b = vector.broadcast %s : vector<1x1x1xf32> to vector<1x2x4xf32>
 %r = vector.extract %b[0, 0] : vector<1x1x4xf32> from vector<1x2x4xf32>
 %t = vector.shape_cast %r : vector<1x1x4xf32> to vector<4xf32>

i.e. if we constrained broadcasts and extracts to be rank retaining, then this would be canonicalized to

 %s = vector.shape_cast %a : vector<1x1xf32> to vector<1x1x1xf32>
 %b = vector.broadcast %s : vector<1x1x1xf32> to vector<1x1x4xf32>
 %t = vector.shape_cast %b : vector<1x1x4xf32> to vector<4xf32>

which, if you have faith that the shape_casts will vanish at a later point, is simpler!

p.s. I plan to reply in #145740 later today
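
The rank rule mentioned in this reply can be sketched as a standalone shape check. This is an illustration under assumptions, not the dialect's verifier: `Shape` and `isBroadcastable` are hypothetical names, scalable dimensions are ignored, and the rule modelled is that the source rank may not exceed the result rank and each source dimension, aligned to the trailing result dimensions, is either 1 or equal to the matching result dimension.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a fixed-size vector shape.
using Shape = std::vector<int64_t>;

// Returns true if `source` can be broadcast to `result` under the rule
// described above.
bool isBroadcastable(const Shape &source, const Shape &result) {
  if (source.size() > result.size())
    return false; // A broadcast can never reduce rank.
  const std::size_t lead = result.size() - source.size();
  for (std::size_t i = 0; i < source.size(); ++i) {
    assert(source[i] > 0 && result[lead + i] > 0 && "dims must be positive");
    // A unit source dim may be stretched; any other dim must match exactly.
    if (source[i] != 1 && source[i] != result[lead + i])
      return false;
  }
  return true;
}
```

Under this rule `{1, 1}` to `{4}` is rejected because the rank would shrink, while `{1, 1}` to `{1, 2, 4}` is accepted, matching the shapes used in the test discussed here.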

@banach-space
Contributor

Thanks!

I ran the SME e2e tests and they all pass. I wasn't able to cherry-pick this into IREE though; I'm getting weird compilation errors. That said, upstream tests should be sufficient to surface all potential issues.

@newling, why not name all "folding" patterns fold...? Wouldn't that be more consistent?

Contributor

@banach-space banach-space left a comment

Thanks!

Comment on lines 810 to 811
Contributor

Why change shapes?

@newling
Contributor Author

newling commented Jun 27, 2025

> Thanks!
>
> I run the SME e2e tests and all pass. I wasn't able to cherry-pick this in IREE though, getting weird compilation errors. Though upstream tests should be sufficient to surface all potential issues.
>
> @newling , why not name all "folding" patterns as fold...? Wouldn't that be more consistent?

I'll give this a spin with IREE

> @newling , why not name all "folding" patterns as fold...? Wouldn't that be more consistent?

Yes, I think so. Actually fe3933d made me wonder if we should split canonicalize.mlir into 2 files (the new one with name fold.mlir containing everything in canonicalize.mlir that only depends on 1-time folds).

@banach-space and @dcaballe thanks for your feedback! Unfortunately I'm going to put this on hold again temporarily, as I've uncovered some other things which should be done before this. Moving back into draft mode, will ping when I think it's ready again.

@newling newling marked this pull request as draft June 27, 2025 00:51
@banach-space
Contributor

> Actually fe3933d made me wonder if we should split canonicalize.mlir into 2 files (the new one with name fold.mlir containing everything in canonicalize.mlir that only depends on 1-time folds).

+1

@newling
Contributor Author

newling commented Aug 6, 2025

This PR is back, and ready for review!

Let me summarize the previous concerns as this is quite old now:

@dcaballe raised concerns about removing tests. I have reinstated all canonicalization tests.
@banach-space raised concerns about test naming and dir structure. I would prefer to address these in a later PR as part of a wider canonicalization/folding test refactor.
@banach-space noted this would likely cause ripples downstream, and suggested running IREE tests. I have done this, and indeed some lit tests fail. I will take responsibility for fixing these (FYI @Groverkss).

Member

@Groverkss Groverkss left a comment

Sorry, I didn't follow the discussion in the previous PR, but something seems wrong to me about moving towards shape_cast being the canonical form for removing unit dimensions.

shape_cast by itself is a more general operation, and requires inferring what the shape_cast actually did. We are throwing away information in each of these examples:

1. %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
2. %1 = vector.transpose %arg1, [1, 0] : vector<2x1xi8> to vector<1x2xi8>
3. %2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8>

For each of these cases, we now need to analyze the shape_cast to see what it was doing:

  1. It is clear that broadcast always does a leading dim broadcast, and we already know it's a unit dim, there is no possibility of collapsing
  2. transpose makes it clear that we are doing a permutation, there is no expansion/collapse of rank happening
  3. extract makes it clear we are rank reducing the vector, there is no possibility of expansion

We are throwing away information, this is in no way a canonicalization. In fact, this should be a preprocessing pattern to prepare for conversion / unrolling / flattening so that we have fewer ops to handle and can cancel out more operations, if we actually want to run something like this.

This can clearly cause problems where we cannot fold things properly because of a shape_cast in between:

%0 = vector.extract %arg0[0] : vector<1x4> from vector<8x1x4>
%1 = vector.extract %0[0] : vector<4> from vector<1x4>

Now, depending on how the patterns run, we could end up with:

%1 = vector.extract %arg0[0, 0] : vector<4> from vector<8x1x4>

or

%0 = vector.extract %arg0[0] : vector<1x4> from vector<8x1x4>
%1 = vector.shape_cast %0 : vector<1x4> to vector<4>

To actually write the canonicalizer of shape_cast(extract()), we will end up actually having to infer if the shape_cast was a vector.extract, which defeats the purpose of this canonicalization.

Also note that each of the cases mentioned are disjoint. They do not overlap in terms of what they can do and the ops carry restrictions on what they can do.

@newling
Contributor Author

newling commented Aug 12, 2025

@Groverkss that's fair, I probably should have done this as an RFC on discourse as there's been some debate and drift since the original #138777

Maybe a less controversial approach is to use functions like isBroadcastLike more liberally in the canonicalizers. Can make isExtractLike and isTransposeLike quite easily, and hopefully using those instead of isa<vector::ExtractOp> and isa<vector::TransposeOp> will get us most of the way there. Not sure though.

@Groverkss
Member

@Groverkss that's fair, I probably should have done this as an RFC on discourse as there's been some debate and drift since the original #138777

Maybe a less controversial approach is to use functions like isBroadcastLike more liberally in the canonicalizers. Can make isExtractLike and isTransposeLike quite easily, and hopefully using those instead of isa<vector::ExtractOp> and isa<vector::TransposeOp> will get us most of the way there. Not sure though.

I would really argue that the opposite pattern should be a canonicalization, since we are caching the analyzed information, by using a more restricted op.

@newling
Contributor Author

newling commented Aug 12, 2025

I would really argue that the opposite pattern should be a canonicalization, since we are caching the analyzed information, by using a more restricted op.

Possibly, but I think "more restricted" depends on how you measure volumes in op-space (I discussed this here). But yeah, I think this canonicalization (in either direction) is not obvious enough to include until we have a formal definition (energy function for fixed point convergence).

@dcaballe
Contributor

dcaballe commented Aug 12, 2025

The point of having a canonical representation is to choose a unique representation for a computation that may have multiple alternative representations. That form doesn't have to be the "best" representation (whatever that means: it doesn’t have to be the most efficient, the most compact, the most aesthetically pleasing, the most restrictive…) It has to be unique and hold all the information that the original form held. We currently have 3+ ways to drop unit dimensions and perform vector reshapes that don’t imply data movement in the virtual vector world and we are paying the cost of maintaining all those redundant patterns. I think the big misalignment here is that we expect the canonical form to be the ideal input for all our use cases without doing any pre-processing and that is something we need to change for the sake of the overall project health.

For me this is similar to other well-established canonical forms in LLVM. For example, look at llvm.vector.reverse, llvm.vector.extract (for sub-vectors), llvm.vector.insert (for sub-vectors), llvm.vector.deinterleave, llvm.vector.interleave, sequences of regular extractelement, insertelement... All these restrictive intrinsics/instructions are canonicalized to shufflevector by InstCombine, and then there are methods to identify the specific shuffle patterns of each restrictive form, in case that’s needed. This is the same thing we are trying to do here.

shape_cast by itself, is a more general operation and requires inferring what the shape_cast actually did. We are throwing away information in each of these examples:

we now need to analyze the shape_cast to see what it was doing:

This sounds contradictory to me. If we can infer the information from the shape_cast itself, what is the information that is being thrown away? There shouldn't be a problem with analyzing the shape_cast op to get the information that we need, or convert that shape cast into something else before a specific pass/transformation. We are talking about comparing the rank of the input and the result types to determine if the reshape is collapsing/expanding/preserving dimensions. I fail to see the problem here.

It is clear that broadcast always does a leading dim broadcast, and we already know it's a unit dim, there is no possibility of collapsing

Retrieving this information from a shape_cast is just about shapeCastInput.getRank() < shapeCastResult.getRank()

transpose makes it clear that we are doing a permutation, there is no expansion/collapse of rank happening

For unit dimension transposes, there is no data permutation. Again, retrieving the expansion/collapse information is just about shapeCastInput.getRank() == shapeCastResult.getRank().

extract makes it clear we are rank reducing the vector, there is no possibility of expansion

Same.
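
As a rough illustration of the rank comparison described above, the following standalone C++ sketch classifies a reshape from the input and result ranks alone. It does not use MLIR; `Shape`, `ReshapeKind` and `classifyByRank` are hypothetical names introduced here, and shapes are assumed to be static.

```cpp
#include <cstdint>
#include <vector>

// A vector shape modelled as a list of static dimension sizes
// (assumption: no scalable dimensions).
using Shape = std::vector<int64_t>;

enum class ReshapeKind { Expanding, Collapsing, RankPreserving };

// Hypothetical helper: classifies a reshape by comparing ranks only.
ReshapeKind classifyByRank(const Shape &input, const Shape &result) {
  if (input.size() < result.size())
    return ReshapeKind::Expanding;
  if (input.size() > result.size())
    return ReshapeKind::Collapsing;
  return ReshapeKind::RankPreserving;
}
```

Under this model, the three examples from the PR description come out as expanding (`4` to `1x1x4`), rank-preserving (`2x1` to `1x2`) and collapsing (`1x4` to `4`).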

Now, depending on how the patterns run, we could end up with:
%1 = vector.extract %arg0[0, 0] : vector<4> from vector<8x1x4>
or
%0 = vector.extract %arg0[0] : vector<1x4> from vector<8x1x4>
%1 = vector.shape_cast %0 : vector<1x4> to vector<4>

This is just about having a folder for vector.shape_cast(vector.extract) -> vector.extract, right?

@dcaballe
Contributor

dcaballe commented Aug 12, 2025

I would really argue that the opposite pattern should be a canonicalization, since we are caching the analyzed information, by using a more restricted op.

Introducing new ops in the canonical form to cache compare(shapeCastInput.getRank(), shapeCastResult.getRank()) sounds like overkill to me...

@banach-space
Contributor

What is this PR achieving?

to me something seems wrong, that we are moving towards shape_cast being the canonical form for removing unit dimensions

To me, this change isn't really about removing unit dims. It's about improving the state of Vector.

In the distant past (~4yrs ago) we did use shape_cast to remove unit dims and then effectively reverted that to restrict its usage. Sadly, that "experiment" didn't really work as intended - we do use vector.shape_cast a lot.

Without a crisp contract, vector.shape_cast remains a foot-gun. To me, this change is more about:

  • containing vector.shape_cast by refining + clarifying its purpose + semantics.

Context + previous PRs

Note, James has already proposed/implemented a number of steps to improve shape_cast:

  • Lowering ([mlir][vector] Improve shape_cast lowering #140800)
  • Folding (#146368 + #135841)

This PR simply takes things further (and was discussed before: #138777). I just wanted to highlight - there have been a number of PRs through which we concluded that this would be beneficial.

Why take this approach?

The point of having a canonical representation is to choose a unique representation for a computation that may have multiple alternative representations

Indeed! With multiple ways of expressing identical computation, let's just agree on one. I don't see any information being lost, so to me this is a matter of simply agreeing on what the canonical form should be and then working with that.

From #138777:

Currently the vector dialect has no strict specification of which of 2 equivalent forms is more canonical

Yes! Clarifying this takes us to a cleaner, better defined state - that's beneficial for Vector.

As a specific argument for this change, replacing:

%0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>

with a vector.shape_cast adds a useful hint: this broadcast is shape-only (no movement), and the op class itself now encodes “view-change”.

Is this really a canonicalization?

Yes.

The original Ops are mechanically recoverable in every case, so I don't see any loss of information here.

This is just about having folder for vector.shape_cast(vector.extract) -> vector.extract, right?

I believe so, yes.

Do we need an RFC?

I probably should have done this as an RFC on discourse

We are discussing fine implementation details within Vector. Since all the maintainers of Vector are already here, it's not clear to me whether taking this to Discourse would help. I’d prefer to keep momentum here unless there’s broader impact beyond vector.

LGTM?

Yes

@newling
Contributor Author

newling commented Aug 14, 2025

I don't mind which way this PR goes. The original goal of the PR was to generate better code in a simpler way, not to define a canonical form. Converting to shape_cast achieves the original goal, but we don't have to say that shape_cast is a canonical form.

As already mentioned in previous comments, technically it is a canonicalization if we want it to be, but until we formalize a canonical form (ideally with an energy function with guarantees of converging to a fixed point, and clear documentation and motivation in a single obvious place) maybe we're better off leaving out things that don't directly reduce op count or provide some other obvious generic improvement. As someone who's tried to unravel complex decisions made 4+ years ago in this dialect (and often not followed through on) I'd rather not leave a whole new paper trail for future developers.

What I've actually been considering for a while is making a 2-step pass that

  • first makes as much as possible a shape_cast. Included here is stripping out the implicit shape_cast before/after broadcasts/from_elements/gather/extract, etc.
  • let shape_casts cancel.
  • make as little as possible a shape_cast (i.e. shape_cast -> transpose, and reabsorb shape_casts into broadcasts/extracts).

I think that would be quite powerful. The nice thing about --canonicalize though is that users of Vector (even those who don't know they are) don't need to know about every optimization written in the past, and add it to their pipeline to get better code. As canonicalize is a 1-step pass, it can't include what I've outlined above. Also, canonicalize must be fast.

@Groverkss
Member

Groverkss commented Aug 15, 2025

Note, James has already proposed/implemented a number of steps to improve shape_cast:

Lowering ([mlir][vector] Improve shape_cast lowering #140800)
Folding (#146368 + #135841)
This PR simply takes things further (and was discussed before: #138777). I just wanted to highlight - there have been a number of PRs through which we concluded that this would be beneficial.

I'm sorry, but this response is a bit confusing to me. I'm very aware of the previous PRs by James, because James and I are aligned on what we are working towards. I didn't chime in on those PRs because I didn't want two people with the same ideas, from the same team and working on similar things, to just review things between themselves and land them, given that we have multiple maintainers for the Vector dialect. I do skim those PRs; until I read the title of this PR, I just wasn't aware that we had concluded these patterns should be canonicalizations.

In fact, I'm more confused that we talk about context here, given that the same change was reverted in the past and I distinctly remember we never came to a conclusion on what the right form is here: #72918. I do not see that we resolved any of the concerns in that discussion, so I don't think concluding this is beneficial is correct, without previous context. There was even an RFC where @banach-space and @dcaballe both were present: https://discourse.llvm.org/t/improving-handling-of-unit-dimensions-in-the-vector-dialect/75216/29. I don't see anybody following up on any of those discussions, so we cannot throw those conclusions away and proceed with something a lot of reviewers disagreed with.

Do we need an RFC?

Given that there was already an RFC about this, with no follow-up on things, yes, we need an RFC on this. A lot of vector dialect users were present in that discussion whose concerns need to be addressed. I'm not okay with this proceeding as a PR given that there was an RFC about this where this PR was not the conclusion.

In the distant past (~4yrs ago) we did use shape_cast to remove unit dims and then effectively reverted that to restrict its usage. Sadly, that "experiment" didn't really work as intended - we do use vector.shape_cast a lot.

I don't know when we concluded that this "experiment" failed. I see that in that RFC @nicolasvasilache had the same point about vector.shape_cast: https://discourse.llvm.org/t/improving-handling-of-unit-dimensions-in-the-vector-dialect/75216/10. This is why we need to address the points in the RFC before we start a discussion on these patterns being canonicalizations.

@dcaballe
Contributor

The nice thing about --canonicalize though is that users of Vector (even those who don't know they are) don't need to know about every optimization written in the past, and add it to their pipeline to get better code.

This is exactly the situation we are trying to improve! There have been numerous complaints over the years about the arbitrary set of disconnected patterns that we have in the Vector dialect and how difficult it is for someone new (and I would say for someone with experience starting a new compiler) to put all of them together and in the right order to produce an optimized output. I’m ok with having different passes that are able to convert one form to another but we need a common starting point, which should be the canonical one.

There was even an RFC where @banach-space and @dcaballe both were present: https://discourse.llvm.org/t/improving-handling-of-unit-dimensions-in-the-vector-dialect/75216/29. I don't see anybody following up on any of those discussions, so we cannot throw those conclusions away and proceed with something a lot of reviewers disagreed with.

That RFC is from 2023. Quite a few things have changed since then. vector.shape_cast was considered experimental at the time. It didn’t support N->M reshapes and the lowering was extremely inefficient. Currently, though, vector.shape_cast semantics have been extended to support N->M reshapes, the lowering has been significantly improved and, most importantly, the op has been widely adopted by a few transformations, including vector linearization, dropping unit dimensions and vector reshape transformations in general. I’d say part of this has happened naturally as part of the development process over time, so any argument for going against this and staying in the representational limbo we are currently in would have to be strongly justified technically in the current context. At this particular point in the discussion, I haven’t been able to understand the difference in semantics between:

%0 = vector.transpose %a, [1, 0] : vector<4x1xf32> to vector<1x4xf32>

and

%0 = vector.shape_cast %a : vector<4x1xf32> to vector<1x4xf32>

The RFC referenced above is also a good example of what can happen when we pull too many voices into a discussion. The discussion derailed and there were people talking over each other with counter-proposals and no conclusion was reached. I know not everyone agrees with this but the Tensor WG was built with the intent to improve on this situation and help build a certain level of consensus before having a broader discussion. IMO, we should start there. Hopefully @rengolin or @rolfmorel can help organize. I think some people are on vacation right now so we may have to wait for a few days/weeks.

@banach-space
Contributor

Returning after some OOO + sick leave. I wanted to clarify my earlier feedback.

I'm sorry, but this response is a bit confusing to me. I'm very aware of the previous PRs by James, because James and I are aligned on what we are working towards.

On context: @Groverkss, I wasn’t suggesting you weren’t aware of previous PRs. My intent was to provide context for other reviewers (including, but not limited to, you) so they could see that this work is part of a longer effort, not an isolated change, and where my support was coming from.

In fact, I'm more confused that we talk about context here, given that the same change was reverted in the past and I distinctly remember we never came to a conclusion on what the right form is here: #72918. I do not see that we resolved any of the concerns in that discussion, so I don't think concluding this is beneficial is correct, without previous context.

On the 2023 concerns: The PR you linked was ~2 years ago. It’s not clear to me whether the issues raised there still apply. If they do, could someone summarize them for us here? The earlier discussions:

involved many people and spanned months, but there’s been little follow-up since. That suggests (to me) that the concerns may no longer be relevant, especially given the changes in the ecosystem that Diego pointed out.

The RFC referenced above is also a good example of what can happen when we pull too many voices into a discussion.
(...)
I know not everyone agrees with this but the Tensor WG was built with the intent to improve on this situation and help build a certain level of consensus before having a broader discussion.

On process: I share the worry that RFCs can stall progress. We’ve seen this before (e.g. the 0-D vectors RFC that took 6 months and required an ODM). That said, I’m not against RFCs, just cautious that we don’t end up in the same situation again. The Tensor WG seems like a good venue to first align and then escalate if needed.

Next step: If anyone believes the 2023 concerns still block this direction, could you restate them here? That would help us decide whether an RFC is warranted, or whether we can move forward with this canonicalization.

In the meantime, I added this to the Tensor WG agenda.

Thank you,
-Andrzej

@banach-space
Contributor

The Tensor WG seems like a good venue to first align and then escalate if needed.

While this was on the agenda for our call today, we didn't get a chance to discuss it (notes from today: https://discourse.llvm.org/t/tcdg-tensor-compiler-design-group-meeting-notes-2025-09-17) - we prioritised the unrolling vs flattening topic that was raised earlier.

We'll get back to this in the next Tensor WG meeting!

@banach-space
Contributor

We'll get back to this in the next Tensor WG meeting!

While we lacked Vector quorum, we did discuss this in our TCWG meeting on 15/10/2025 (see notes here). The focus was on unblocking the discussion.

The overall sentiment was that we should proceed with this unless specific and actionable concerns are raised. Regarding the earlier discussion on Discourse, much has changed since then, and no concrete follow-ups occurred. Given the ~2-year gap, it’s reasonable to assume that the previous concerns no longer hold.

Previously, vector.transpose was the most contentious part of the discussion. Today, mechanisms like isOrderPreserving (introduced by @newling in #135841) provide the necessary safeguards to prevent "problematic" conversion of vector.transpose to vector.shape_cast.
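
To make the idea of an order-preserving transpose concrete, here is a minimal standalone model. It is not the upstream `isOrderPreserving` implementation: it assumes static shapes, a valid permutation, and the `vector.transpose` convention that result dimension `i` is input dimension `perm[i]`. The name `isOrderPreservingModel` is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Returns true if the non-unit dimensions keep their relative order under
// `perm`, i.e. the transpose only moves unit dimensions around and therefore
// does not reorder elements in memory.
bool isOrderPreservingModel(const std::vector<int64_t> &inputShape,
                            const std::vector<std::size_t> &perm) {
  assert(inputShape.size() == perm.size() && "rank mismatch");
  bool seenNonUnit = false;
  std::size_t lastNonUnit = 0;
  for (std::size_t source : perm) {
    if (inputShape[source] == 1)
      continue; // Unit dimensions can move freely.
    if (seenNonUnit && source < lastNonUnit)
      return false; // Two non-unit dimensions swapped places.
    lastNonUnit = source;
    seenNonUnit = true;
  }
  return true;
}
```

In this model, transposing `2x1` with `[1, 0]` is order preserving, while transposing `2x3` with `[1, 0]` is not, so only the former would be a candidate for rewriting to a shape_cast.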

Given LLVM’s liberal revert policy, I propose that we land this and, if issues arise, revert. We’ve been converging toward this change for some time.

-Andrzej

Contributor

@banach-space banach-space left a comment

LGTM, the overall direction makes a lot of sense to me, massive thanks for working on this!


if (!isOrderPreserving(transpose)) {
  return rewriter.notifyMatchFailure(
      transpose, "not order preserving, so not semantically a 'copy'");
Contributor

Did you mean "semantically" without "not"?

Suggested change
transpose, "not order preserving, so not semantically a 'copy'");
transpose, "not order preserving, so semantically a 'copy'");

/// to fold shape_cast into create_mask.
///
/// BEFORE:
/// to fold shape_cast into creatto a greater number of BEFORE:
Contributor

Something has gone wrong here :)

@Groverkss
Member

We'll get back to this in the next Tensor WG meeting!

While we lacked Vector quorum, we did discuss this in our TCWG meeting on 15/10/2025 (see notes here). The focus was on unblocking the discussion.

The overall sentiment was that we should proceed with this unless specific and actionable concerns are raised. Regarding the earlier discussion on Discourse, much has changed since then, and no concrete follow-ups occurred. Given the ~2-year gap, it’s reasonable to assume that the previous concerns no longer hold.

Previously, vector.transpose was the most contentious part of the discussion. Today, mechanisms like isOrderPreserving (introduced by @newling in #135841) provide the necessary safeguards to prevent "problematic" conversion of vector.transpose to vector.shape_cast.

Given LLVM’s liberal revert policy, I propose that we land this and, if issues arise, revert. We’ve been converging toward this change for some time.

-Andrzej

Hi, I'll try to write up concerns by Monday next week and ask for an RFC on how those concerns are addressed. I have raised concerns about this, other people have raised concerns about this, and if you want to proceed please send an RFC. I am happy to summarise all problems with this change by Monday.

Landing this without an RFC even when requested would be against the LLVM policy.

Member

@Groverkss Groverkss left a comment

This change has had problems in the past with concerns raised from @MaheshRavishankar @kuhar @qedawkins and me. I have talked with them privately recently, and we still have the same concerns. I will summarise the previous concerns by Monday.

Can you please send an RFC summarising what has changed and how you plan to handle the previously raised concerns? Happy to send a summary of the concerns by Monday which you can use.

Please do not land this without sending an RFC and getting explicit approval on the RFC.

@Groverkss
Member

I did write a summary of previous concerns today, but I'm talking to others who were on previous threads to check whether the concerns are still valid (some of them don't look valid to me). I'll post an update tomorrow; due to timezone differences I couldn't sync with some people in US timezones today.

@Groverkss
Member

The main issue with this canonicalization comes down to the fact that it drops
structural information of operations, which goes against "structured code
generation".

The crux of the issue raised in the previous PRs, reverts and discourse
discussion is that shape_cast is not a "structured" N-D transformation, while
transpose, broadcast and extract are.

This is visible when describing each of these operations based on how their
result dimensions map to their input dimensions:

shape_cast -> output_dims = delinearize(linearize(input_dims, basis=input_shape), basis=output_shape)
transpose -> output_dims = permute(input_dims)
broadcast -> output_dims = concat(new_dims, input_dims)
extract -> output_dims = drop_front(input_dims, num_dropped=old_rank-new_rank)

Note that transpose, broadcast and extract represent a projected permutation
transformation while shape_cast represents an arbitrary reshape transformation
which is less restricted than a projected permutation.
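
The index maps listed above can be modelled with a small standalone sketch. This is an illustration under stated assumptions rather than MLIR code: shapes are static, layout is row-major, and the helper names (`linearize`, `delinearize`, `shapeCastIndex`, `transposeIndex`) are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Index = std::vector<int64_t>;
using Shape = std::vector<int64_t>;

// Row-major linearization of `idx` within `shape`.
int64_t linearize(const Index &idx, const Shape &shape) {
  assert(idx.size() == shape.size() && "rank mismatch");
  int64_t linear = 0;
  for (std::size_t d = 0; d < shape.size(); ++d)
    linear = linear * shape[d] + idx[d];
  return linear;
}

// Inverse of linearize: recovers an N-D index from a linear offset.
Index delinearize(int64_t linear, const Shape &shape) {
  Index idx(shape.size(), 0);
  for (std::size_t d = shape.size(); d-- > 0;) {
    idx[d] = linear % shape[d];
    linear /= shape[d];
  }
  return idx;
}

// shape_cast model: go through the linear offset.
Index shapeCastIndex(const Index &in, const Shape &inShape,
                     const Shape &outShape) {
  return delinearize(linearize(in, inShape), outShape);
}

// transpose model: result dimension i is input dimension perm[i].
Index transposeIndex(const Index &in, const std::vector<std::size_t> &perm) {
  Index out;
  out.reserve(perm.size());
  for (std::size_t p : perm)
    out.push_back(in[p]);
  return out;
}
```

For a `2x3` to `3x2` reshape, input index `(1, 0)` lands at `(1, 1)` under the shape_cast model but at `(0, 1)` under a `[1, 0]` transpose. For `2x1` to `1x2` both models send `(1, 0)` to `(0, 1)`, which is the overlap this PR targets.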

When coming from tensor land via vectorization, the vector code produced
preserves this structural information which is very useful. There are no
arbitrary reshapes with linearize/delinearize semantics, and dimensional
structure is preserved.

This PR tries to canonicalize different ops which can do different things but have
a single point where they are similar, to a single op which has the least
structure preserving semantics of them all.

The reasonings I'm seeing in this PR are:

The point of having a canonical representation is to choose a unique
representation for a computation that may have multiple alternative
representations. That form doesn't have to be the "best" representation
(whatever that means: it doesn’t have to be the most efficient, the most
compact, the most aesthetically pleasing, the most restrictive…) It has to be
unique and hold all the information that the original form held.

The original Ops are mechanically recoverable in every case, so I don't see
any loss of information here.

I don't see canonicalizing towards a less restricted op to be a goal of
canonicalization at all:
https://mlir.llvm.org/docs/Canonicalization/#general-design

In fact, that is only the initial bar for something to be considered as a
canonicalization:

Canonicalize shouldn’t lose the semantic of original operation: the original
information should always be recoverable from the transformed IR.

It's not the bar for something to be a canonicalization.

The goal that is relevant here, that I do see is:

The goal of canonicalization is to make subsequent analyses and optimizations
more effective. Therefore, performance improvements are not necessary for
canonicalization.

There is no justification in this PR that canonicalizing towards a single operation
makes subsequent analysis and optimizations more effective. In fact, it requires
transformations to now add an edge case for special shape_cast that
could be the special form it was looking for. For example:

Let's say I'm writing a transformation (a commonly written one), which folds
a transpose-like op into a permutation map.

transpose 1x2x2x8 -> 2x2x8x1
contract { indexing_maps = [#map] }

transform to

contract { indexing_maps = [#transposed_map] }

It is easy to write, with just matching operations that have transpose like
semantics.

shape_cast 8x8 -> 2x2x8x1
contract { indexing_maps = [#map] }

Not only does my transformation have to look for transpose-like ops, it also has to
look for shape_cast's that are special cased to be transpose-like.

This is making subsequent transformations harder to apply, this is actively going
against the goal of canonicalizations.
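
A sketch of the kind of extra check described above: deciding whether a shape_cast only adds, removes or moves unit dimensions, and so could stand in for a unit-dim transpose, broadcast or extract. This is a standalone model assuming static shapes; `dropUnitDims` and `onlyChangesUnitDims` are hypothetical names, not existing MLIR helpers.

```cpp
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// Returns `shape` with all size-1 dimensions removed.
Shape dropUnitDims(const Shape &shape) {
  Shape result;
  for (int64_t dim : shape)
    if (dim != 1)
      result.push_back(dim);
  return result;
}

// True if the reshape leaves the sequence of non-unit dimensions untouched,
// i.e. it only adds, removes or moves unit dimensions.
bool onlyChangesUnitDims(const Shape &input, const Shape &result) {
  return dropUnitDims(input) == dropUnitDims(result);
}
```

With the shapes from the example above, `1x2x2x8` to `2x2x8x1` passes this check, while `8x8` to `2x2x8x1` does not, so a pattern matching on shape_cast would need a test of this sort before treating the op as transpose-like.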

In practice, we have a number of patterns in the SPIRV backend of our downstream
compiler (IREE), which do transformations that look for
transpose/broadcast/extract/insert ops and use their semantics to analyze
before doing transformations (They don't rely on them for correctness). These
patterns will not run anymore, because they all need to handle a special case
of shape_cast acting like other ops, even though the pipeline never introduced
it.

Some examples (from @kuhar):

https://github.com/iree-org/iree/blob/main/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVInitialVectorLowering.cpp#L148
https://github.com/iree-org/iree/blob/main/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp#L509

I'm sure I could find more examples from upstream, but these are just off
the top of my head.

Keeping this in mind, I don't agree this transformation is a canonicalization
for the entire vector dialect. It's okay to add this as a transformation pass,
pattern methods, or anything, as long as it's opt-in and doesn't force every
user of vector dialect to use this.

If you think otherwise, please send an RFC and we can discuss further there. This
is a very wide reaching change and I would like you to address why it's okay to
throw semantic information like projected permutation semantics and ask
transformations to analyze it back.

The reason I ask for an RFC is that there are justifications given in this PR
which I have no meaningful way to measure, and it falls on the author to prove
that this transformation is a canonicalization:

We currently have 3+ ways to drop unit dimensions and perform vector reshapes
that don’t imply data movement in the virtual vector world and we are paying
the cost of maintaining all those redundant patterns. I think the big
misalignment here is that we expect the canonical form to be the ideal input
for all our use cases without doing any pre-processing and that is something
we need to change for the sake of the overall project health.

A simple example of why this isn't a good argument is that we also have linalg.contraction, linalg.generic and linalg.matmul in the same dialect, but we don't choose to canonicalize them to each other. We have 3+ ways of writing a matmul in linalg. We should be consistent with this argument and apply it everywhere if so.

What I would like to see in the RFC is:

  1. Show what "redundant" patterns we remove from having this canonicalization.
  2. Measure the impact of upstream patterns where every pattern looking for
    a transpose/broadcast/extract now has to also account for a shape_cast,
    because one edge case falls through to that.
  3. Show why this canonicalization aligns with the goals of canonicalizations.

@joker-eph
Collaborator

joker-eph commented Oct 21, 2025

There is no justification in this PR that canonicalizing towards a single operation makes subsequent analysis and optimizations more effective.

"In general" having to match a single form for the same program is what makes subsequent analysis more effective: that is we can match a single "thing" instead of needing to match many. So this sentence as-is is a bit odd to me (but you're elaborating pretty well in the rest of your post, so don't take this as a rebuttal of your post as a whole but just a minor nitpick of this particular sentence).
That said, it does not mean "one op", it just means "pick a direction and stick to it". For InstCombine for a while it was something like (from memory) "if a form has fewer operations or fewer SSA values in flight, then it's 'simpler'".
It's very valid to consider that one form is "more structured" than the other, but usually the conclusion is that this form is the more canonical one and the canonicalization should go the other direction. In the end: "one form is preferable to many" is the underlying mindset.
