Skip to content

Conversation

@newling
Copy link
Contributor

@newling newling commented May 19, 2025

Discussions suggest that we should use shape_cast as a canonical form of broadcast/transpose/extract where possible (see #138777)

For example these can all be expressed as shape casts:

%0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
%1 = vector.transpose %arg1, [1, 0] : vector<2x1xi8> to vector<1x2xi8>
%2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8>

This PR adds canonicalizes to convert the above 3 examples to shape_casts.

I've added some more comments as review comments.

I'm happy to split this PR up and add the new patterns separately.

@newling newling changed the title [vector][mlir] Canonicalize to shape_cast where possible [wip][vector][mlir] Canonicalize to shape_cast where possible May 19, 2025
@newling newling changed the title [wip][vector][mlir] Canonicalize to shape_cast where possible [vector][mlir] Canonicalize to shape_cast where possible May 19, 2025
@newling newling force-pushed the canonicalize_to_shape_cast branch from d546ab3 to 29d41d8 Compare June 5, 2025 18:07
@newling newling force-pushed the canonicalize_to_shape_cast branch from 29d41d8 to f2e5417 Compare June 25, 2025 23:03
@github-actions
Copy link

github-actions bot commented Jun 25, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@banach-space I'm getting back to this PR. Peephole question: is this operation ok? i.e. is

vector.shape_cast %a vector<[4]x1xf32> to vector<1x[4]xf32>

an acceptable operation to have after running mlir-opt -arm-sme-vector-legalization -cse -canonicalize ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general, yes. But I can't guarantee there's no logic that expects vector<[4]x1xf32> instead of vector<1x[4]xf32> ;-) If that's the case, we will fix it and I will be grateful for uncovering this :)

@newling newling force-pushed the canonicalize_to_shape_cast branch from 7bc5da0 to e673522 Compare June 26, 2025 15:31
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Author note: I've removed this, as now it happens in 2 steps during canonicalization. The first converts the Broadcast to a ShapeCast. The second combines the 2 ShapeCasts.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Author note: I've removed this, as it now happens in 2 steps during canonicalization. The first (new) step is to rewrite the transpose as a shape_cast. The second step is to fold shape_cast(shape_cast) to shape_cast.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Author note: I've removed this pattern, as it is a special case of TransposeToShapeCast

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Author note: removed these tests, as the pattern they are testing is removed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we keep them? shouldn't they still be canonicalized?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll add them back, yes they're still canonicalized

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Author note: as the vector.transpose is canonicalized to a vector.shape_cast, the lowering test is now moved to shape_cast lowering

@llvmbot
Copy link
Member

llvmbot commented Jun 26, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir-sme

Author: James Newling (newling)

Changes

Discussions suggest that we should use shape_cast as a canonical form of broadcast/transpose/extract where possible (see #138777)

For example these can all be expressed as shape casts:

%0 = vector.broadcast %arg0 : vector&lt;4xi8&gt; to vector&lt;1x1x4xi8&gt;
%1 = vector.transpose %arg1, [1, 0] : vector&lt;2x1xi8&gt; to vector&lt;1x2xi8&gt;
%2 = vector.extract %arg2[0] : vector&lt;4xi8&gt; from vector&lt;1x4xi8&gt;

This PR adds canonicalizes to convert the above 3 examples to shape_casts.

I've added some more comments as review comments.

I'm happy to split this PR up and add the new patterns separately.


Patch is 41.81 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140583.diff

10 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+84-53)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp (-61)
  • (modified) mlir/test/Dialect/ArmSME/vector-legalization.mlir (+4-4)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+26-41)
  • (modified) mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir (+2-2)
  • (added) mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir (+162)
  • (modified) mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir (+60)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir (+5-7)
  • (modified) mlir/test/Dialect/Vector/vector-transpose-lowering.mlir (-85)
  • (modified) mlir/test/Dialect/Vector/vector-warp-distribute.mlir (+4-4)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 862ed7bae1fbb..08cc4af158e10 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2351,11 +2351,41 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
   return success();
 }
 
+/// BEFORE:
+/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
+/// AFTER:
+/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType sourceType = extractOp.getSourceVectorType();
+    VectorType outType = dyn_cast<VectorType>(extractOp.getType());
+    if (!outType)
+      return failure();
+
+    // Negative values in `position` indicates poison, which cannot be
+    // represented with a shape_cast
+    if (llvm::any_of(extractOp.getMixedPosition(),
+                     [](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
+      return failure();
+
+    if (sourceType.getNumElements() != outType.getNumElements())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, outType,
+                                                     extractOp.getVector());
+    return success();
+  }
+};
+
 } // namespace
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+  results
+      .add<ExtractOpFromBroadcast, ExtractOpFromCreateMask, ExtractToShapeCast>(
+          context);
   results.add(foldExtractFromShapeCastToShapeCast);
   results.add(foldExtractFromFromElements);
 }
@@ -2867,13 +2897,36 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
     return success();
   }
 };
+
+/// BEFORE:
+/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+/// AFTER:
+/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+struct BroadcastToShapeCast final
+    : public OpRewritePattern<vector::BroadcastOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::BroadcastOp broadcast,
+                                PatternRewriter &rewriter) const override {
+    auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType());
+    if (!sourceType) {
+      return rewriter.notifyMatchFailure(
+          broadcast, "source is a scalar, shape_cast doesn't support scalar");
+    }
+
+    VectorType outType = broadcast.getType();
+    if (sourceType.getNumElements() != outType.getNumElements())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcast, outType,
+                                                     broadcast.getSource());
+    return success();
+  }
+};
 } // namespace
 
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
-  // calling `populateCastAwayVectorLeadingOneDimPatterns`
-  results.add<BroadcastFolder>(context);
+  results.add<BroadcastFolder, BroadcastToShapeCast>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -5991,10 +6044,7 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
   }
 };
 
-/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
-///   i) Y = ShapeCast(X), or
-///  ii) Y = Broadcast(X)
-/// If both (i) and (ii) are possible, (i) is chosen.
+/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as Y = Broadcast(X)
 class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -6009,22 +6059,6 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
     auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
     bool srcIsScalar = !srcVectorType;
 
-    // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
-    // Example:
-    // %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
-    // %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
-    // to
-    // %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
-    if (srcVectorType) {
-      if (srcVectorType.getNumElements() ==
-          shapeCastOp.getResultVectorType().getNumElements()) {
-        rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
-            shapeCastOp, shapeCastOp.getResultVectorType(),
-            broadcastOp.getSource());
-        return success();
-      }
-    }
-
     // Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
     // Example
     // %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
@@ -6233,7 +6267,7 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
   // %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
   // %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
   //
-  // Example of what NOT to fold:
+  // Example of what not to fold:
   // %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
   //
   if (getSourceVectorType() == getResultVectorType() &&
@@ -6359,32 +6393,6 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
   }
 };
 
-/// Folds transpose(shape_cast) into a new shape_cast.
-class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(TransposeOp transposeOp,
-                                PatternRewriter &rewriter) const override {
-    auto shapeCastOp =
-        transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
-    if (!shapeCastOp)
-      return failure();
-    if (!isOrderPreserving(transposeOp))
-      return failure();
-
-    VectorType resultType = transposeOp.getType();
-
-    // We don't need to check isValidShapeCast at this point, because it is
-    // guaranteed that merging the transpose into the the shape_cast is a valid
-    // shape_cast, because the transpose just inserts/removes ones.
-
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
-                                                     shapeCastOp.getSource());
-    return success();
-  }
-};
-
 /// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
 /// 'order preserving', where 'order preserving' means the flattened
 /// inputs and outputs of the transpose have identical (numerical) values.
@@ -6480,12 +6488,35 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
   }
 };
 
+/// BEFORE:
+/// %0 = vector.transpose %arg0, [0, 2, 1] :
+///                   vector<2x1x2xf32> to vector<2x2x1xf32>
+/// AFTER:
+/// %0 = vector.shape_cast %arg0 :
+///                   vector<2x1x2xf32> to vector<2x2x1xf32>
+struct TransposeToShapeCast final
+    : public OpRewritePattern<vector::TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::TransposeOp transpose,
+                                PatternRewriter &rewriter) const override {
+
+    if (!isOrderPreserving(transpose)) {
+      return rewriter.notifyMatchFailure(
+          transpose, "not order preserving, so not semantically a 'copy'");
+    }
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+        transpose, transpose.getType(), transpose.getVector());
+    return success();
+  }
+};
+
 } // namespace
 
 void vector::TransposeOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
-  results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
-              FoldTransposeSplat, FoldTransposeBroadcast>(context);
+  results.add<FoldTransposeBroadcast, FoldTransposeCreateMask,
+              FoldTransposeSplat, TransposeFolder, TransposeToShapeCast>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 732e316c93381..71410eda28297 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -11,7 +11,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -382,64 +381,6 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
   vector::VectorTransposeLowering vectorTransposeLowering;
 };
 
-/// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied
-/// to 2D vectors with at least one unit dim. For example:
-///
-/// Replace:
-///   vector.transpose %0, [1, 0] : vector<4x1xi32>> to
-///                                 vector<1x4xi32>
-/// with:
-///   vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32>
-///
-/// Source with leading unit dim (inverse) is also replaced. Unit dim must
-/// be fixed. Non-unit dim can be scalable.
-///
-/// TODO: This pattern was introduced specifically to help lower scalable
-/// vectors. In hindsight, a more specialised canonicalization (for shape_cast's
-/// to cancel out) would be preferable:
-///
-///  BEFORE:
-///     %0 = some_op
-///     %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32>
-///     %2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
-///  AFTER:
-///     %0 = some_op
-///     %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32>
-///
-/// Given the context above, we may want to consider (re-)moving this pattern
-/// at some later time. I am leaving it for now in case there are other users
-/// that I am not aware of.
-class Transpose2DWithUnitDimToShapeCast
-    : public OpRewritePattern<vector::TransposeOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  Transpose2DWithUnitDimToShapeCast(MLIRContext *context,
-                                    PatternBenefit benefit = 1)
-      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
-
-  LogicalResult matchAndRewrite(vector::TransposeOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.getVector();
-    VectorType resType = op.getResultVectorType();
-
-    // Set up convenience transposition table.
-    ArrayRef<int64_t> transp = op.getPermutation();
-
-    if (resType.getRank() == 2 &&
-        ((resType.getShape().front() == 1 &&
-          !resType.getScalableDims().front()) ||
-         (resType.getShape().back() == 1 &&
-          !resType.getScalableDims().back())) &&
-        transp == ArrayRef<int64_t>({1, 0})) {
-      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
-      return success();
-    }
-
-    return failure();
-  }
-};
-
 /// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
 /// If the strategy is Shuffle1D, it will be lowered to:
 ///   vector.shape_cast 2D -> 1D
@@ -511,8 +452,6 @@ class TransposeOp2DToShuffleLowering
 void mlir::vector::populateVectorTransposeLoweringPatterns(
     RewritePatternSet &patterns,
     VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit) {
-  patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(),
-                                                  benefit);
   patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
       vectorTransposeLowering, patterns.getContext(), benefit);
 }
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index 6cdf576272ebc..a9a2fdccdd82f 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -480,11 +480,11 @@ func.func @lift_illegal_transpose_to_memory_with_in_bounds_attr(%a: index, %b: i
 
 // -----
 
-// The pass should do nothing (and not crash).
-// CHECK-LABEL: @illegal_transpose_no_defining_source_op
-func.func @illegal_transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32>
+// CHECK-LABEL: @transpose_no_defining_source_op
+func.func @transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32>
 {
-  // CHECK: vector.transpose
+  // CHECK:      vector.shape_cast
+  // CHECK-SAME: vector<[4]x1xf32> to vector<1x[4]xf32>
   %0 = vector.transpose %vec, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
   return %0 : vector<1x[4]xf32>
 }
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 65b73375831da..374c71c814e89 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -451,16 +451,25 @@ func.func @extract_strided_fold_insert(%a: vector<2x8xf32>, %b: vector<1x4xf32>,
 // -----
 
 // CHECK-LABEL: transpose_3D_identity
-// CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>)
+//  CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>)
+//  CHECK-NEXT: return [[ARG]]
 func.func @transpose_3D_identity(%arg : vector<4x3x2xf32>) -> vector<4x3x2xf32> {
-  // CHECK-NOT: transpose
   %0 = vector.transpose %arg, [0, 1, 2] : vector<4x3x2xf32> to vector<4x3x2xf32>
-  // CHECK-NEXT: return [[ARG]]
   return %0 : vector<4x3x2xf32>
 }
 
 // -----
 
+// CHECK-LABEL: transpose_0D_identity
+//  CHECK-SAME: ([[ARG:%.*]]: vector<i8>)
+//  CHECK-NEXT: return [[ARG]]
+func.func @transpose_0D_identity(%arg : vector<i8>) -> vector<i8> {
+  %0 = vector.transpose %arg, [] : vector<i8> to vector<i8>
+  return %0 : vector<i8>
+}
+
+// -----
+
 // CHECK-LABEL: transpose_2D_sequence
 // CHECK-SAME: ([[ARG:%.*]]: vector<4x3xf32>)
 func.func @transpose_2D_sequence(%arg : vector<4x3xf32>) -> vector<4x3xf32> {
@@ -753,12 +762,13 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
 
 // -----
 
+
 // CHECK-LABEL: negative_fold_extract_broadcast
-//       CHECK:   vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
-//       CHECK:   vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32>
+//       CHECK:   vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x2x4xf32>
+//       CHECK:   vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x2x4xf32>
 func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32> {
-  %b = vector.broadcast %a : vector<1x1xf32> to vector<1x1x4xf32>
-  %r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x1x4xf32>
+  %b = vector.broadcast %a : vector<1x1xf32> to vector<1x2x4xf32>
+  %r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x2x4xf32>
   return %r : vector<4xf32>
 }
 
@@ -797,8 +807,8 @@ func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
 // rank(extract_output) < rank(broadcast_input)
 func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
   %idx0 : index, %idx1 : index) -> vector<4xf32> {
-  %b = vector.broadcast %a : vector<2x4xf32> to vector<1x2x4xf32>
-  %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
+  %b = vector.broadcast %a : vector<2x4xf32> to vector<2x2x4xf32>
+  %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<2x2x4xf32>
   return %r : vector<4xf32>
 }
 
@@ -1033,30 +1043,6 @@ func.func @canonicalize_broadcast_shapecast_to_broadcast_scalar(%arg0: f32) -> v
 
 // -----
 
-// In this test, broadcast (2)->(1,2,1) is not legal, but shape_cast (2)->(1,2,1) is.
-// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapcast
-//   CHECK-NOT:   vector.broadcast
-//       CHECK:   vector.shape_cast {{.+}} : vector<2xf32> to vector<1x2x1xf32>
-func.func @canonicalize_broadcast_shapecast_to_shapcast(%arg0 : vector<2xf32>) -> vector<1x2x1xf32> {
-  %0 = vector.broadcast %arg0 : vector<2xf32> to vector<1x2xf32>
-  %1 = vector.shape_cast %0 : vector<1x2xf32> to vector<1x2x1xf32>
-  return %1 : vector<1x2x1xf32>
-}
-
-// -----
-
-// In this test, broadcast (1)->(1,1) and shape_cast (1)->(1,1) are both legal. shape_cast is chosen.
-// CHECK-LABEL: func @canonicalize_broadcast_shapecast_both_possible
-//   CHECK-NOT:   vector.broadcast
-//       CHECK:   vector.shape_cast {{.+}} : vector<1xf32> to vector<1x1xf32>
-func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>) -> vector<1x1xf32> {
-    %0 = vector.broadcast %arg0 : vector<1xf32> to vector<1x1x1xf32>
-    %1 = vector.shape_cast %0 : vector<1x1x1xf32> to vector<1x1xf32>
-    return %1 : vector<1x1xf32>
-}
-
-// -----
-
 // CHECK-LABEL: fold_vector_transfer_masks
 func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
   // CHECK: %[[C0:.+]] = arith.constant 0 : index
@@ -1920,12 +1906,12 @@ func.func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> {
 
 // -----
 
-// CHECK-LABEL: func @insert_extract_to_broadcast
+// CHECK-LABEL: func @insert_extract_to_shape_cast
 //  CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>)
-//       CHECK:   %[[V0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
-//       CHECK:   %[[V1:.*]] = vector.broadcast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
+//       CHECK:   %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32>
+//       CHECK:   %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
 //       CHECK:   return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32>
-func.func @insert_extract_to_broadcast(%arg0 : vector<1x1x4xf32>,
+func.func @insert_extract_to_shape_cast(%arg0 : vector<1x1x4xf32>,
   %arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) {
   %0 = vector.extract %arg0[0, 0] : vector<4xf32> from vector<1x1x4xf32>
   %1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32>
@@ -2277,7 +2263,7 @@ func.func @shuffle_1d_rhs_poison() -> vector<4xi32> {
 
 // CHECK-LABEL: func @shuffle_canonicalize_0d
 func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vector<1xi32> {
-  // CHECK: vector.broadcast %{{.*}} : vector<i32> to vector<1xi32>
+  // CHECK: vector.shape_cast %{{.*}} : vector<i32> to vector<1xi32>
   %shuffle = vector.shuffle %v0, %v1 [0] : vector<i32>, vector<i32>
   return %shuffle : vector<1xi32>
 }
@@ -2764,9 +2750,8 @@ func.func @transfer_read_from_rank_reducing_extract_slice(%src: tensor<1x8x8x8xf
 // CHECK-LABEL: func.func @extract_from_broadcast
 func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> {
   %0 = vector.broadcast %src : vector<1x1x1xf32> to vector<1x1x32x1xf32>
-
-  //  CHECK-NEXT:   %0 = vector.extract {{.*}}[0, 0] : vector<1xf32> from vector<1x1x1xf32>
-  //  CHECK-NEXT:   return %0 : vector<1xf32>
+  //  CHECK-NEXT:   %[[RES:.*]] = vector.shape_cast{{.*}} vector<1x1x1xf32> to vector<1xf32>
+  //  CHECK-NEXT:   return %[[RES]] : vector<1xf32>
   %1 = vector.extract %0[0, 0, 31] : vector<1xf32> from vector<1x1x32x1xf32>
   return %1: vector<1xf32>
 }
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
index fdab2a8918a2e..d5f96a8928770 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
@@ -81,8 +81,8 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<
 
 // CHECK-LABEL: func @to_shape_cast_rank2_to_rank1(
 //  CHECK-SAME:       %[[A:.*]]: vector<1x2xi8>)
-//       CHECK:       %[[EXTRACT:.*]] = vector.extract %[[A]][0] : vector<2xi8> from vector<1x2xi8>
-//       CHECK:       return %[[EXTRACT]] : vector<2xi8>
+//       CHECK:       %[[SC:.*]] = vector.shape_cast %[[A]] : vector<1x2xi8> to vector<2xi8>
+//       CHECK:       return %[[SC]] : vector<2xi8>
 func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> {
   %0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
   %1 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8>
diff --git...
[truncated]

@newling
Copy link
Contributor Author

newling commented Jun 26, 2025

Hi @banach-space and @dcaballe, I've pulled this PR out of draft mode, so please feel free to comment on it whenever!

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! LGTM in general. The only general comment is to make sure we don't reduce testing coverage. I think we should keep/update the tests even for those cases where the pattern is removed.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Keep both tests, one with the original shape and one with the new ones?

Unrelated: it looks like we are missing a canonicalization patter here? This should be turned into a single vector.broadcast to vector<4xf32>?

Copy link
Contributor Author

@newling newling Jun 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Keep both tests, one with the original shape and one with the new ones?

Makes sense, will do.

Unrelated: it looks like we are missing a canonicalization patter here? This should be turned into a single vector.broadcast to vector<4xf32>?

No because you can't broadcast <1x1xf32> to <4xf32> -- broadcasts can never reduce rank in Vector. FWIW slightly related to my comment here where this would be simpler if ops didn't do implicit shape casting. In this case if it was something like

 %s = vector.shape_cast %a : vector<1x1xf32> to vector<1x1x1xf32>
 %b = vector.broadcast %s : vector<1x1x1xf32> to vector<1x2x4xf32>
 %r = vector.extract %b[0, 0] : vector<1x1x4xf32> from vector<1x2x4xf32>
 %s = vector.shape_cast %r : vector<1x1x4> to vector<4>

ie if we constrained broadcasts and extracts to be rank retaining, then this would be canonicalized to

 %s = vector.shape_cast %a : vector<1x1xf32> to vector<1x1x1xf32>
 %b = vector.broadcast %s : vector<1x1x1xf32> to vector<1x1x4xf32>
 %s = vector.shape_cast %b : vector<1x1x4> to vector<4>

which, if you have faith that the shape_casts will vanish at a later point, is simpler!

p.s. I plan to reply in #145740 later today

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we keep them? shouldn't they still be canonicalized?

@banach-space
Copy link
Contributor

Thanks!

I run the SME e2e tests and all pass. I wasn't able to cherry-pick this in IREE though, getting weird compilation errors. Though upstream tests should be sufficient to surface all potential issues.

@newling , why not name all "folding" patterns as fold...? Wouldn't that be more consistent?

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

Comment on lines 810 to 811
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why change shapes?

@newling
Copy link
Contributor Author

newling commented Jun 27, 2025

Thanks!

I run the SME e2e tests and all pass. I wasn't able to cherry-pick this in IREE though, getting weird compilation errors. Though upstream tests should be sufficient to surface all potential issues.

@newling , why not name all "folding" patterns as fold...? Wouldn't that be more consistent?

I'll give this a spin with IREE

@newling , why not name all "folding" patterns as fold...? Wouldn't that be more consistent?

Yes, I think so. Actually fe3933d made me wonder if we should split canonicalize.mlir into 2 files (the new one with name fold.mlir containing everything in canonicalize.mlir that only depends on 1-time folds).

@banach-space and @dcaballe thanks for your feedback! Unfortunately I'm going to put this on hold again temporarily, as I've uncovered some other things which should be done before this. Moving back into draft mode, will ping when I think it's ready again.

@newling newling marked this pull request as draft June 27, 2025 00:51
@banach-space
Copy link
Contributor

Actually fe3933d made me wonder if we should split canonicalize.mlir into 2 files (the new one with name fold.mlir containing everything in canonicalize.mlir that only depends on 1-time folds).

+1

@newling
Copy link
Contributor Author

newling commented Aug 6, 2025

This PR is back, and ready for review!

Let me summarize the previous concerns as this is quite old now:

@dcaballe raised concerns about removing tests. I have reinstated all canonicalization tests.
@banach-space raised concerns about test naming and dir structure. I would prefer to address these in a later PR as part of a wider canonicalization/folding test refactor.
@banach-space noted this would likely cause ripples downstream, and suggested running IREE tests. I have done this, and indeed some lit tests fail. I will take responsibility for fixing these (FYI @Groverkss).

@banach-space
Copy link
Contributor

What is this PR achieving?

to me something seems wrong, that we are moving towards shape_cast being the canonical form for removing unit dimensions

To me, this change isn't really about removing unit dims. Its about improving the state of Vector.

In the distant past (~4yrs ago) we did use shape_cast to remove unit dims and then effectively reverted that to restrict its usage. Sadly, that "experiment" didn't really work as intended - we do use vector.shape_cast a lot.

Without a crisp contract, vector.shape_cast remains a foot-gun. To me, this change is more about:

  • containing vector.shape_cast by refining + clarifying it's purpose + semantics.

Context + previous PRs

Note, James has already proposed/implemented a number of steps to improve shape_cast:

This PR simply takes things further (and was discussed before: #138777). I just wanted to highlight - there has been a number of PRs through which we concluded that this would be beneficial.

Why take this approach?

The point of having a canonical representation is to choose a unique representation for a computation that may have multiple alternative representations

Indeed! With multiple ways of expressing identical computation, lets just agree on one. I don't see any information being lost, so to me this a matter of simply agreeing on what the canonical form should be and then working with that.

From #138777:

Currently the vector dialect has no strict specification of which of 2 equivalent forms is more canonical

Yes! Clarifying this takes us to a cleaner, better defined state - that's beneficial for Vector.

As a specific argument for this change, replacing:

%0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>

with a vector.shape_cast adds a useful hint: this broadcast is shape-only (no movement), and the op class itself now encodes “view-change”.

Is this really a canonicalization?

Yes.

The original Ops are mechanically recoverable in every case, so I don't see any loss of information here.

This is just about having folder for vector.shape_cast(vector.extract) -> vector.extract, right?

I believe so, yes.

Do we need an RFC?

I probably should have done this as an RFC on discourse

We are discussing fine implementation details within Vector. Since all the maintainers of Vector are already here, it's not clear to me whether taking this to Discourse would help. I’d prefer to keep momentum here unless there’s broader impact beyond vector.

LGTM?

Yes

@newling
Copy link
Contributor Author

newling commented Aug 14, 2025

I don't mind which way this PR goes. The original goal of the PR was to generate better code in a simpler way, not to define a canonical form. Converting to shape_cast achieves the original goal, but we don't have to say that shape_cast is a canonical form.

As already mentioned in previous comments, technically it is a canonicalization if we want it to be, but until we formalize a canonical form (ideally with an energy function with guarantees of converging to a fixed point, and clear documentation and motivation in a single obvious place) maybe we're better off leaving things that don't directly reduce op count or provide some other obvious generic improvement, out. As someone who's tried to unravel complex decisions make 4+ years ago in this dialect (and often not followed through on) I'd rather not leave a whole new paper trail for future developers.

What I've actually been considering for a while is making a 2-step pass that

  • first makes as much as possible a shape_cast. Included here is stripping out the implicit shape_cast before/after broadcasts/from_elements/gather/extract, etc.
  • let shape_casts cancel.
  • make as little as possible a shape_cast (i.e. shape_cast -> transpose, and reabsorb shape_casts into broadcasts/extracts).

I think that would be quite powerful. The nice thing about --canonicalize though is that users of Vector (even those who don't know they are) don't need to know about every optimization written in the past, and add it to their pipeline to get better code. As canonicalize is a 1-step pass, it can't include what I've outlined above. Also, canonicalize must be fast.

@Groverkss
Copy link
Member

Groverkss commented Aug 15, 2025

Note, James has already proposed/implemented a number of steps to improve shape_cast:

Lowering ([mlir][vector] Improve shape_cast lowering #140800)
Folding (#146368 + #135841)
This PR simply takes things further (and was discussed before: #138777). I just > wanted to highlight - there has been a number of PRs through which we concluded that this would be beneficial.

I'm sorry, but this response is a bit confusing to me. I'm very aware of the previous prs by James, because James and I are aligned on what we are working towards. I didn't chime in on those prs because I didn't want two people with the same ideas , from the same team and working on similar things to just review things between them and land things, given that we have multiple maintainers for Vector dialect. I do skim by those prs, I just didn't read until the title of this PR that we came to a conclusion that these patterns should be canonicalizations.

In fact, i'm more confused that we talk about context here, given that the same change was reverted in the past and I distinctly remember we never came to conclusion what is the right form here: #72918 . I do not see that we resolved any of the concerns in that discussion, so I don't think concluding this is beneficial is correct, without previous context. There was even an RFC where @banach-space and @dcaballe both were present: https://discourse.llvm.org/t/improving-handling-of-unit-dimensions-in-the-vector-dialect/75216/29 .I don't see anybody followup on any of those discussions, so we cannot throw those conclusions away and proceed with something a lot of reviewers disagreed with.

Do we need an RFC?

Given that there was already an RFC about this, with no followup on things, yes we need an RFC on this. A lot of vector dialect users were present in that discussion whose concerns need to be addressed. I'm not okay with this proceeding as a PR given that there was an RFC about this where this PR was not the conclusion.

In the distant past (~4yrs ago) we did use shape_cast to remove unit dims and then effectively reverted that to restrict its usage. Sadly, that "experiment" didn't really work as intended - we do use vector.shape_cast a lot.

I don't know when we concluded that this "experiment" failed. I see that in that RFC @nicolasvasilache had the same point about vector.shape_cast https://discourse.llvm.org/t/improving-handling-of-unit-dimensions-in-the-vector-dialect/75216/10 . This is why we need to address the points in the RFC before we start a discussion on these patterns being canonicalizations.

@dcaballe
Copy link
Contributor

The nice thing about --canonicalize though is that users of Vector (even those who don't know they are) don't need to know about every optimization written in the past, and add it to their pipeline to get better code.

This is exactly the situation we are trying to improve! There have been numerous complaints along the years about the arbitrary set of disconnected patterns that we have in the Vector dialect and how difficult is for someone new (and I would say for someone with experience starting a new compiler) to put all of them together and in the right order to produce an optimized output. I’m ok with having different passes that are able to convert one form to another but we need a common starting point, which should be the canonical one.

There was even an RFC where @banach-space and @dcaballe both were present: https://discourse.llvm.org/t/improving-handling-of-unit-dimensions-in-the-vector-dialect/75216/29 .I don't see anybody followup on any of those discussions, so we cannot throw those conclusions away and proceed with something a lot of reviewers disagreed with.

That RFC is from 2023. Quite a few things have changed since then. vector.shape_cast was considered experimental at the time. It didn’t support N->M reshapes and the lowering was extremely inefficient. Currently, though, vector.shape_cast semantics have been extended to support N-M reshapes, the lowering has been significantly improved and, most importantly, the op has been widely adopted by a few transformations, including vector linearization, dropping unit dimensions and vector reshape transformations, in general. I’d say part of this has naturally happened as part of the development process over time so if there are concerns to go against this and continue in the representational limbo we are currently in, that should have to be strongly justified technically in the current context. At this particular point in the discussion, I haven’t been able to understand the difference in semantics between:

%0 = vector.transpose %a, [1, 0] : vector<4x1xf32> to vector<1x4xf32>

and

%0 = vector.shape_cast %a : vector<4x1xf32> to vector<1x4xf32>

The RFC referenced above is also a good example of what can happen when we pull too many voices into a discussion. The discussion derailed and there were people talking over each other with counter-proposals and no conclusion was reached. I know not everyone agrees with this but the Tensor WG was built with the intent to improve on this situation and help build certain level of consensus before having a broader discussion. IMO, we should start there. Hopefully @rengolin or @rolfmorel can help organize. I think some people are on vacation right now so we may have to wait for a few days/weeks.

@banach-space
Copy link
Contributor

Returning after some OOO + sick leave. I wanted to clarify my earlier feedback.

I'm sorry, but this response is a bit confusing to me. I'm very aware of the previous prs by James, because James and I are aligned on what we are working towards.

On context: @Groverkss ,I wasn’t suggesting you weren’t aware of previous PRs. My intent was to provide context for other reviewers (including, but not limited to you) so they could see that this work is part of a longer effort, not an isolated change. And where my support was coming from.

In fact, i'm more confused that we talk about context here, given that the same change was reverted in the past and I distinctly remember we never came to conclusion what is the right form here: #72918 . I do not see that we resolved any of the concerns in that discussion, so I don't think concluding this is beneficial is correct, without previous context.

On the 2023 concerns: The PR you linked was ~2 years ago. It’s not clear to me whether the issues raised there still apply. If they do, could someone summarize them for us here? The earlier discussions:

involved many people and spanned months, but there’s been little follow-up since. That suggests (to me) that the concerns may no longer be relevant, especially given the changes in the ecosystem that Diego pointed out.

The RFC referenced above is also a good example of what can happen when we pull too many voices into a discussion.
(...)
I know not everyone agrees with this but the Tensor WG was built with the intent to improve on this situation and help build certain level of consensus before having a broader discussion.

On process: I share the worry that RFCs can stall progress. We’ve seen this before (e.g. the 0-D vectors RFC that took 6 months and required an ODM). That said, I’m not against RFCs, just cautious that we don’t end up in the same situation again. The Tensor WG seems like a good venue to first align and then escalate if needed.

Next step: If anyone believes the 2023 concerns still block this direction, could you restate them here? That would help us decide whether an RFC is warranted, or whether we can move forward with this canonicalization.

In the meantime, I added this to the Tensor WG agenda.

Thank you,
-Andrzej

@banach-space
Copy link
Contributor

The Tensor WG seems like a good venue to first align and then escalate if needed.

While this was on the agenda for our call today, we didn't get a chance to discuss it (notes from today: https://discourse.llvm.org/t/tcdg-tensor-compiler-design-group-meeting-notes-2025-09-17) - we prioritised the unrolling vs flattening topic that was raised earlier.

We'll get back to this in the next Tensor WG meeting!

@banach-space
Copy link
Contributor

We'll get back to this in the next Tensor WG meeting!

While we lacked Vector quorum, we did discuss this in our TCWG meeting on 15/10/2025 (see notes here). The focus was on unblocking the discussion.

The overall sentiment was that we should proceed with this unless specific and actionable concerns are raised. Regarding the earlier discussion on Discourse, much has changed since then, and no concrete follow-ups occurred. Given the ~2-year gap, it’s reasonable to assume that the previous concerns no longer hold.

Previously, vector.transpose was the most contentious part of the discussion. Today, mechanisms like isOrderPreserving
(introduced by @newling in #135841) provide the necessary safeguards to prevent "problematic" conversion of vector.transpose to vector.shape_cast.

Given LLVM’s liberal revert policy, I propose that we land this and, if issues arise, revert. We’ve been converging toward this change for some time.

-Andrzej

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, the overall direction makes a lot of sense to me, massive thanks for working on this!


if (!isOrderPreserving(transpose)) {
return rewriter.notifyMatchFailure(
transpose, "not order preserving, so not semantically a 'copy'");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you mean "semantically" without "not"?

Suggested change
transpose, "not order preserving, so not semantically a 'copy'");
transpose, "not order preserving, so semantically a 'copy'");

/// to fold shape_cast into create_mask.
///
/// BEFORE:
/// to fold shape_cast into creatto a greater number of BEFORE:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something has gone wrong here :)

@Groverkss
Copy link
Member

We'll get back to this in the next Tensor WG meeting!

While we lacked Vector quorum, we did discuss this in our TCWG meeting on 15/10/2025 (see notes here). The focus was on unblocking the discussion.

The overall sentiment was that we should proceed with this unless specific and actionable concerns are raised. Regarding the earlier discussion on Discourse, much has changed since then, and no concrete follow-ups occurred. Given the ~2-year gap, it’s reasonable to assume that the previous concerns no longer hold.

Previously, vector.transpose was the most contentious part of the discussion. Today, mechanisms like isOrderPreserving (introduced by @newling in #135841) provide the necessary safeguards to prevent "problematic" conversion of vector.transpose to vector.shape_cast.

Given LLVM’s liberal revert policy, I propose that we land this and, if issues arise, revert. We’ve been converging toward this change for some time.

-Andrzej

Hi, I'll try to write up concerns by Monday next week and ask for an RFC on how those concerns are addressed. I have raised concerns about this, other people have raised concerns about this, and if you want to proceed please send a RFC. I am happy to summarise all problems with this change by Monday.

Landing this without a RFC even when requested would be against the LLVM policy.

Copy link
Member

@Groverkss Groverkss left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change has had problems in the past with concerns raised from @MaheshRavishankar @kuhar @qedawkins and me. I have talked with them privately recently, and we still have the same concerns. I will summarise the previous concerns by Monday.

Can you please send a RFC summarising what has changed and how you plan to handle the previously raised concerns? Happy to send a summary of the concerns by Monday which you can use.

Please do not land this without sending a RFC and getting explicit approval on the RFC.

@Groverkss
Copy link
Member

I did write a summary of previous concerns today, but I'm talking to others who were on previous threads to make sure if the concerns are still valid (some of them don't look valid to me). I'll post an update tommorow, due to timezone differences I couldn't sync with some people from US timezone today.

@Groverkss
Copy link
Member

The main issue with this canonicalization comes down to the fact that it drops
structural information of operations, which goes against "structured code
generation".

The crux of the issue raised in the previous prs, reverts and discourse
discussion is that shape_cast is not a "structured" N-D transformation, while
transpose, broadcast and extract are.

This is visible when describing each of these operations based on how their
result dimensions map to their input dimensions:

shape_cast -> output_dims = delinearize(linearize(input_dims, basis=input_shape), basis=output_shape)
transpose -> output_dims = permute(input_dims)
broadcast -> output_dims = concat(new_dims, input_dims)
extract -> output_dims = drop_front(input_dims, num_dropped=old_rank-new_rank)

Note that transpose, broadcast and extract represent a projected permutation
transformation while shape_cast represents an arbritary reshape transformation
which is less restricted than a projected permutation.

When coming from tensor land via vectorization, the vector code produced
preserves this structural information which is very useful. There are no
arbritary reshapes with linearize/delinearize semantics, and dimensional
structure is preserved.

This PR tries to canonicalize different ops which can do different things but have
a single point where they are similar, to a single, op which has the least
structure preserving semantics of them all.

The reasonings I'm seeing in this PR are:

The point of having a canonical representation is to choose a unique
representation for a computation that may have multiple alternative
representations. That form doesn't have to be the "best" representation
(whatever that means: it doesn’t have to be the most efficient, the most
compact, the most aesthetically pleasing, the most restrictive…) It has to be
unique and hold all the information that the original form held.

The original Ops are mechanically recoverable in every case, so I don't see
any loss of information here.

I don't see canonicalizing towards a less restricted op to be a goal of
canonicalization at all:
https://mlir.llvm.org/docs/Canonicalization/#general-design

Infact, that is only just the initial bar for something to be considered as a
canonicalization:

Canonicalize shouldn’t lose the semantic of original operation: the original
information should always be recoverable from the transformed IR.

It's not the bar for something to be a canonicalization.

The goal that is relevant here, that I do see is:

The goal of canonicalization is to make subsequent analyses and optimizations
more effective. Therefore, performance improvements are not necessary for
canonicalization.

There is no justification in this PR that canonicalizing towards a single operation
makes subsequent analysis and optimizations more effective. Infact, it requires
transformations to now add an edge case for special shape_cast that
could be the special form it was looking for. For example:

Let's say I'm writing a transformation (a commonly written one), which folds
a transpose-like op into a permutation map.

transpose 1x2x2x8 -> 2x2x8x1
contract { indexing_maps = [#map] }

transform to

contract { indexing_maps = [#transposed_map] }

It is easy to write, with just matching operations that have transpose like
semantics.

shape_cast 8x8 -> 2x2x8x1
contract { indexing_maps = [#map] }

Not only does my transformation has to look for transpose-like ops, it has to
look for shape_cast's that are special cased to be transpose-like.

This is making subsequent transformations harder to apply, this is actively going
against the goal of canonicalizations.

In practice, we have number of patterns in SPIRV backend in our downstream
compiler (IREE), which do transformations that look for
transpose/broadcast/extract/insert ops and use their semantics to analyze
before doing transformations (They don't rely on them for correctness). These
patterns will not run anymore, because they all need to handle a special case
of shape_cast acting like other ops, even though the pipeline never introduced
it.

Some examples (from @kuhar):

https://github.com/iree-org/iree/blob/main/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVInitialVectorLowering.cpp#L148
https://github.com/iree-org/iree/blob/main/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp#L509

I'm sure I could find more examples from upstream, but these are just from
the top of my head.

Keeping this in mind, I don't agree this transformation is a canonicalization
for the entire vector dialect. It's okay to add this as a transformation pass,
pattern methods, or anything, as long as it's opt-in and doesn't force every
user of vector dialect to use this.

If you think otherwise, please send a RFC and we can discuss further there. This
is a very wide reaching change and I would like you to address why it's okay to
throw semantic information like projected permutation semantics and ask
transformations to analyze it back.

The reason I ask for an RFC is that there are justifications given in this PR
which I have no meaningful way to measure, and it falls on the author to prove
that this transformation is a canonicalization:

We currently have 3+ ways to drop unit dimensions and perform vector reshapes
that don’t imply data movement in the virtual vector world and we are paying
the cost of maintaining all those redundant patterns. I think the big
misalignment here is that we expect the canonical form to be the ideal input
for all our use cases without doing any pre-processing and that is something
we need to change for the shake of the overall project health.

A simple example of why this isn't a good arguement is we also have linalg.contraction, linalg.generic and linalg.matmul in the same dialect, but we don't choose to canonicalize them to each other. We have 3+ ways of writing a matmul in linalg. We should be consistent with this argument and apply it everywhere if so.

What I would like to see in the RFC is:

  1. Show what "redundant" patterns we remove from having this canonicalization.
  2. Measure the impact of upstream patterns where every pattern looking for
    a transpose/broadcast/extract now has to also account for a shape_cast,
    because one edge case falls through to that.
  3. Show why this canonicalization aligns with the goals of canonicalizations.

@joker-eph
Copy link
Collaborator

joker-eph commented Oct 21, 2025

There is no justification in this PR that canonicalizing towards a single operation makes subsequent analysis and optimizations more effective.

"In general" having to match a single form for the same program is what makes subsequent analysis more effective: that is we can match a single "thing" instead of needing to match many. So this sentence as-is is a bit odd to me (but you're elaborating pretty well in the rest of your post, so don't take this as a rebuttal of your post as a whole but just a minor nitpick of this particular sentence).
That said it does not mean "one op", it just means "pick a direction and stick to it". For InstCombine for a while it was something like (from memory) "if a form has less number of operations or less SSA values in flight, then it's 'simpler'".
It's very valid to consider than one form is "more structured" than the other, but usually the conclusion is that this form is the more canonical one and the canonicalization should go the other direction. In then end: "one form is preferable to many" is the underlying mindset.

@banach-space
Copy link
Contributor

Hi Kunwar, thanks for your message - and apologies for the slow reply, it’s been a hectic week.

I’d like to catch up F2F at the LLVM Dev Meeting next week with as many folks as possible involved in the discussion. I’ll propose some concrete next steps after that.

@MaheshRavishankar
Copy link
Contributor

Hi Kunwar, thanks for your message - and apologies for the slow reply, it’s been a hectic week.

I’d like to catch up F2F at the LLVM Dev Meeting next week with as many folks as possible involved in the discussion. I’ll propose some concrete next steps after that.

Kunwar wont be there. Quinn and I are there and I am sure we can represent this view point there.

@dcaballe
Copy link
Contributor

I had written a reply but I agree that this needs high-bandwidth discussion. I'll also be around next week!

@banach-space
Copy link
Contributor

banach-space commented Nov 2, 2025

Hi everyone,

Here’s a quick summary from the LLVM Dev discussion (myself, @dcaballe , @MaheshRavishankar and @matthias-springer were present):

We’ve reviewed this extensively, and I’ve seen no evidence that this change would:

  • Drop information that cannot be recovered.
  • Break something that cannot be fixed (*).

Neither Diego nor I see any technical issues with this PR - every concern raised so far can be addressed without difficulty (**). As one specific point, let me return to Kunwar’s example, i.e. a pattern matching:

vector.transpose 1x2x2x8 -> 2x2x8x1
vector.contract { indexing_maps = [#map] }

to rewrite that as:

vector.contract { indexing_maps = [#transposed_map] }

Such a pattern should also match:

vector.shape_cast 8x8 -> 2x2x8x1
vector.contract { indexing_maps = [#map] }

Why? Because in this case vector.transpose and vector.shape_cast are semantically equivalent. A transformation that handles only vector.transpose is simply missing equivalent cases. This PR doesn’t make such transformations more or less correct - it only exposes unhandled but valid cases.

Ultimately, this PR is about agreeing on the canonical form. My view remains that operations not requiring data movement should canonicalise to vector.shape_cast, which helps solidify its semantics and keeps the model consistent. However, in the absence of a clear technical criterion (“energy function”) for preferring one form over another, much of this comes down to preference.

Both Diego and I support this change, but there is still disagreement on whether it qualifies as canonical (***). @ftynse 's talk on “Normal Forms for MLIR” from last week (****) offers a promising direction, but we’ll likely need to wait for Alex to implement it and propose an RFC. While that means progress here may pause for now, this discussion has helped clarify the underlying principles.

Thank you for the discussion and for all the input. If I missed or misinterpreted something, please post a clarification.

HTH

-Andrzej

(*) Mahesh offered to share an example, so I may revisit this point.
(**) This PR is deliberately limited in scope and only covers straightforward cases.
(***) We could still land it as a non-canononicalisation.
(****) Video should be available soon.

@MaheshRavishankar
Copy link
Contributor

Thank you @banach-space for summarizing the discussion. I think I vehemently agree with this

Ultimately, this PR is about agreeing on the canonical form. My view remains that operations not requiring data movement should canonicalise to vector.shape_cast, which helps solidify its semantics and keeps the model consistent. However, in the absence of a clear technical criterion (“energy function”) for preferring one form over another, much of this comes down to preference.

Both Diego and I support this change, but there is still disagreement on whether it qualifies as canonical (). @ftynse 's talk on “Normal Forms for MLIR” from last week (*) offers a promising direction, but we’ll likely need to wait for Alex to implement it and propose an RFC. While that means progress here may pause for now, this discussion has helped clarify the underlying principles.

Specifically

However, in the absence of a clear technical criterion (“energy function”) for preferring one form over another, much of this comes down to preference.

is a nice formulation of the problem.

@joker-eph
Copy link
Collaborator

@ftynse 's talk on “Normal Forms for MLIR” from last week (*) offers a promising direction,

I talked to @ftynse after the talk, and it seems that canonicalization is orthogonal to this concept of normal form, and I really don't see how canonicalization as a concept would go away.

However, in the absence of a clear technical criterion (“energy function”) for preferring one form over another, much of this comes down to preference.

This is not a rebuttal of canonicalization: we can try to apply principles to pick one or another form, but if at the end of everything two forms are completely equivalent, picking one is just a convention (e.g. "constants are moved to the right of commutative operations").

@dcaballe
Copy link
Contributor

dcaballe commented Nov 4, 2025

However, in the absence of a clear technical criterion (“energy function”) for preferring one form over another, much of this comes down to preference.

is a nice formulation of the problem.

I think we all agree that choosing a canonical form could be a matter of preference but that's not the issue we're discussing here. Blocking this proposal means we will NOT have a canonical form at all. This leaves developers having to handle all potential variants AFTER running the canonicalizer which is not a good state to be for anyone.

Adding to what Andrzej mentioned, I'd like to share some key ideas and points that were made during our F2F discussion:

  1. At this point, we haven't found a case that can't be converted from the shape cast form to the broadcast/transpose form.
  2. Reshape operations are not "low level" operations. They are part of high-level abstractions even at user level.
  3. In some projects like IREE, reshapes fall back to a not so optimal path as they are usually difficult to tile, fuse and optimize in general. Broadcast/transposes are handled more efficiently and, therefore, that form is preferred.
  4. This proposal does not block the use of the broadcast/transpose form. It only changes the canonical form. Using a shape cast form as the canonical form will also include providing conversion mechanisms to go from shape cast form to broadcast/transpose form.
  5. This aligns in practice with the idea of having multiple IR forms, along the lines of what Alex presented in his talk.
  6. As a consequence of (4), IREE would actually support more cases through the broadcast/transpose form and the more optimal path, as there will be a way to convert the shape cast form to the broadcast/transpose form.
  7. A counter argument to (4) is that this could add complexity for users who would need to reason about and convert between different IR forms.
  8. Retrieving the information needed to convert from the shape cast form to the broadcast/transpose form is straightforward. It requires comparing the input and output shapes of the shape cast operations and look for unit dimensions.
  9. A counter argument to (8) is that even if straightforward, this is considered "analysis" and makes the shape cast form not ideal to be the canonical form.

To reiterate our primary goal here: we're not trying to break any downstream projects but to improve the overall situation for all the Vector dialect users. The shape cast problem has been a recurring pain point as previous RFCs and discussions demonstrate. There have been multiple "energy functions" trying to improve the overall situation for everyone but they have been shut down without offering a sound alternative. Even at this point, I'm still failing to see a strong technical justification for the block.

@MaheshRavishankar
Copy link
Contributor

I was hoping that we had reached some sort of understanding, but the earlier response seems to get back to the same loop of discussion. I'd like to break out of it. We are in agreement about almost all of it, except for the aspect of canonicalization, and I will try once again to clarify what has been said earlier.

Adding to what Andrzej mentioned, I'd like to share some key ideas and points that were made during our F2F discussion:

  1. At this point, we haven't found a case that can't be converted from the shape cast form to the broadcast/transpose form.

I know, this is on me, I havent collected it yet, but I will try to dig this up.

  1. Reshape operations are not "low level" operations. They are part of high-level abstractions even at user level.

I disagree that it is. Just because they could come in some input program does not make it a high-level abstraction. Higher level abstractions are necessarily operations that carry "more information in a succinct manner". You could have Linalg ops and affine ops exist in the same input program. That does not make them at the same level of abstraction.

  1. In some projects like IREE, reshapes fall back to a not so optimal path as they are usually difficult to tile, fuse and optimize in general. Broadcast/transposes are handled more efficiently and, therefore, that form is preferred.

Again, I disagree. This is not about IREE, and not relevant to the conversation. The IREE project has enough people with the background needed to handle these things. If anything, folks working on IREE are bringing their perspective of building a full scale compiler using these constructs to bear here cause we see how more directly the impact of choices made upstream. The IREE project has ways to deal with any changes that dont align with IREE. This has nothing to do with tile + fuse since this is at vector level. To restate, overall we are on the same page in terms of shape_casts as the right abstraction to handle unit dim folding before lowering to LLVM.

  1. This proposal does not block the use of the broadcast/transpose form. It only changes the canonical form. Using a shape cast form as the canonical form will also include providing conversion mechanisms to go from shape cast form to broadcast/transpose form.

This is the crux of the problem. Canonicalization in MLIR are applied repeatedly and without control. That is how they are designed. There have been several talks by Matthias, and Alex that talk about how canonicalizations can silent change your compilers behavior just on an LLVM bump. Canonical form implies it is better always to chose this form over others, and when you cant prove that, you cant define the canonical form. "Just pick one" is not a robust enough solution long term (someone else can come along and pick something else, and there is no forcing function that disallows that). The fact that canonicalizations run everywhere without control + canonicalizations in MLIR have become a kitchen sink of patterns is really the problem.

  1. This aligns in practice with the idea of having multiple IR forms, along the lines of what Alex presented in his talk.

I dont think it does. The whole point of Alex's talk (and this was the motivation for normal forms) is that in a lot of cases, it is impossible to decide canonical forms.

  1. As a consequence of (4), IREE would actually support more cases through the broadcast/transpose form and the more optimal path, as there will be a way to convert the shape cast form to the broadcast/transpose form.

Again, lets leave IREE out of it. How IREE handles thing is completely irrelevant to this discussion.

  1. A counter argument to (4) is that this could add complexity for users who would need to reason about and convert between different IR forms.
  2. Retrieving the information needed to convert from the shape cast form to the broadcast/transpose form is straightforward. It requires comparing the input and output shapes of the shape cast operations and look for unit dimensions.
  3. A counter argument to (8) is that even if straightforward, this is considered "analysis" and makes the shape cast form not ideal to be the canonical form.

To reiterate our primary goal here: we're not trying to break any downstream projects but to improve the overall situation for all the Vector dialect users. The shape cast problem has been a recurring pain point as previous RFCs and discussions demonstrate. There have been multiple "energy functions" trying to improve the overall situation for everyone but they have been shut down without offering a sound alternative. Even at this point, I'm still failing to see a strong technical justification for the block.

I understand you are not trying to break downstream projects, and you see this as a way to improve overall situation. There is a disagreement that making this "canonical" does that. The only disagreement is the requirement that we have this as canonicalization. Lots of folks here disagree that it is.

We dont need to block progress on things just for this one point about picking a canonical form. Lot of things added here are definitely useful for many downstream projects. Particularly, as a sequence of transformations to convert vector to LLVM dialect, having a patterns that convert broadcasts, transposes, etc. to shape casts and folds them away is extremely valuable. We are just blocking here cause we want to make some representation canonical. Instead we could provide a sequence of transformations that will take a vector dialect program and convert it to LLVM, which is readily usable everywhere. I would suggest we go down that path instead of just going in circles trying to decide on a canonical form.

@MaheshRavishankar
Copy link
Contributor

@dcaballe and @banach-space see this post here https://discourse.llvm.org/t/rfc-update-to-general-design-section-of-operation-canonicalizations-in-mlir/79355?u=maheshravishankar . This talks about the how vector.transpose captures more information than a vector.shape_cast and how you cannot always go from shape_cast to transpose.

This is exactly the issue with treating vector.shape_cast as "canonical" representation for transposes and hoping that we can lift back to the original representation always.

@banach-space
Copy link
Contributor

Thanks @MaheshRavishankar , as promised I am returning to this after you've shared your example.

see this post here https://discourse.llvm.org/t/rfc-update-to-general-design-section-of-operation-canonicalizations-in-mlir/79355?u=maheshravishankar . This talks about the how vector.transpose captures more information than a vector.shape_cast and how you cannot always go from shape_cast to transpose.

I've extracted this repro as something representative (*):

func.func @transpose_to_shape_cast_1(%0 : vector<4x1x1xf32>) -> vector<1x4x1xf32> {
  %res = vector.transpose %0, [2, 0, 1] : vector<4x1x1xf32> to vector<1x4x1xf32>
  return %res : vector<1x4x1xf32>
}

// -----

func.func @transpose_to_shape_cast_2(%0 : vector<4x1x1xf32>) -> vector<1x4x1xf32> {
  %res = vector.transpose %0, [1, 0, 2] : vector<4x1x1xf32> to vector<1x4x1xf32>
  return %res : vector<1x4x1xf32>
}

QUESTION/COMMENT:

Aren't the examples above identical operations?

YES - LLVM example!
Lets try these:

# Canonicalize to vector.shape_cast, then lower.
$ mlir-opt  repro.mlir -canonicalize -test-lower-to-llvm --split-input-file
# Lower as vector.transpose.
$ mlir-opt  repro.mlir -test-lower-to-llvm --split-input-file

In both cases I get the following (testing using this PR):

module {
  llvm.func @transpose_to_shape_cast_1(%arg0: !llvm.array<4 x array<1 x vector<1xf32>>>) -> !llvm.array<1 x array<4 x vector<1xf32>>> {
    %0 = llvm.mlir.poison : !llvm.array<1 x array<4 x vector<1xf32>>>
    %1 = llvm.extractvalue %arg0[0, 0] : !llvm.array<4 x array<1 x vector<1xf32>>>
    %2 = llvm.insertvalue %1, %0[0, 0] : !llvm.array<1 x array<4 x vector<1xf32>>>
    %3 = llvm.extractvalue %arg0[1, 0] : !llvm.array<4 x array<1 x vector<1xf32>>>
    %4 = llvm.insertvalue %3, %2[0, 1] : !llvm.array<1 x array<4 x vector<1xf32>>>
    %5 = llvm.extractvalue %arg0[2, 0] : !llvm.array<4 x array<1 x vector<1xf32>>>
    %6 = llvm.insertvalue %5, %4[0, 2] : !llvm.array<1 x array<4 x vector<1xf32>>>
    %7 = llvm.extractvalue %arg0[3, 0] : !llvm.array<4 x array<1 x vector<1xf32>>>
    %8 = llvm.insertvalue %7, %6[0, 3] : !llvm.array<1 x array<4 x vector<1xf32>>>
    llvm.return %res : !llvm.array<1 x array<4 x vector<1xf32>>>
  }
}

// -----
module {
  llvm.func @transpose_to_shape_cast_2(%arg0: !llvm.array<4 x array<1 x vector<1xf32>>>) -> !llvm.array<1 x array<4 x vector<1xf32>>> {
    %0 = llvm.mlir.poison : !llvm.array<1 x array<4 x vector<1xf32>>>
    %1 = llvm.extractvalue %arg0[0, 0] : !llvm.array<4 x array<1 x vector<1xf32>>>
    %2 = llvm.insertvalue %1, %0[0, 0] : !llvm.array<1 x array<4 x vector<1xf32>>>
    %3 = llvm.extractvalue %arg0[1, 0] : !llvm.array<4 x array<1 x vector<1xf32>>>
    %4 = llvm.insertvalue %3, %2[0, 1] : !llvm.array<1 x array<4 x vector<1xf32>>>
    %5 = llvm.extractvalue %arg0[2, 0] : !llvm.array<4 x array<1 x vector<1xf32>>>
    %6 = llvm.insertvalue %5, %4[0, 2] : !llvm.array<1 x array<4 x vector<1xf32>>>
    %7 = llvm.extractvalue %arg0[3, 0] : !llvm.array<4 x array<1 x vector<1xf32>>>
    %res = llvm.insertvalue %7, %6[0, 3] : !llvm.array<1 x array<4 x vector<1xf32>>>
    llvm.return %res : !llvm.array<1 x array<4 x vector<1xf32>>>
  }
}

Note, %res == %arg0, which confirms that we are dealing with a NO-OP.

YES - SPIR-V example!
Lets try these:

# Canonicalize to vector.shape_cast, then lower.
$ mlir-opt  repro.mlir -canonicalize -test-convert-to-spirv --split-input-file
# Lower as vector.transpose.
$ mlir-opt  repro.mlir -test-convert-to-spirv --split-input-file

In both cases I get the following (testing using this PR):

module {
  func.func @transpose_to_shape_cast_1(%arg0: vector<1xf32>, %arg1: vector<1xf32>, %arg2: vector<1xf32>, %arg3: vector<1xf32>) -> (vector<1xf32>, vector<1xf32>, vector<1xf32>, vector<1xf32>) {
    return %arg0, %arg1, %arg2, %arg3 : vector<1xf32>, vector<1xf32>, vector<1xf32>, vector<1xf32>
  }
}

// -----
module {
  func.func @transpose_to_shape_cast_2(%arg0: vector<1xf32>, %arg1: vector<1xf32>, %arg2: vector<1xf32>, %arg3: vector<1xf32>) -> (vector<1xf32>, vector<1xf32>, vector<1xf32>, vector<1xf32>) {
    return %arg0, %arg1, %arg2, %arg3 : vector<1xf32>, vector<1xf32>, vector<1xf32>, vector<1xf32>
  }
}

SPIR-V makes it even clearer that we are dealing with a NO-OP 😅

FINAL THOUGHTS

I argue that in all cases we are dealing with one operation for which we have multiple names (vector.transpose [2, 0, 1] vs vector.transpose [1, 0, 2] vs vector.shape_cast). This discussion is merely trying to establish a single name for all of this.

I obviously might be missing something - please correct me know if that's the case. I am sharing this to make my mental model clear and to avoid confusion.

-Andrzej

(*) Please provide other examples if this does not capture what you had in mind.

@MaheshRavishankar
Copy link
Contributor

@banach-space I think I might not have communicated the intent of my example properly. This was more to show why vector.shape_cast is not the canonical representation of transposes with unit-dims, and that you cannot always recover the transpose from a shape_cast cause it looses information. So a vector.transpose carries more information than a vector.shape_cast.

It of course lowers to the same thing if you are just lowering to LLVM or SPIR-V. I have already agreed that while lowering to LLVM you should lower both these transposes to vector.shape_cast and then just cancel these out, cause at LLVM level this makes no difference. So it is perfectly valid to say that for the sequence of transformations lowering from vector dialect to LLVM without any further context specific analysis, all transpose with unit-dims should be lowered to shape_cast. In other words, the "normal form" for the set of passes that lower to vector dialect ops to LLVM prefers using shape_cast. I am on board with that. Canonical form is different. As shown above, the vector.transpose to vector.shape_cast and back is not always possible to do. You "lose" information when you lower to vector.shape_cast. That makes it non-canonical since a canonicalizer would not help you reach a "better state" of the program.

@dcaballe
Copy link
Contributor

dcaballe commented Nov 14, 2025

I think there is a misunderstanding about what we can expect from a canonical form. A canonical form should indeed allow us to convert to any equivalent representation. However, while these representations are semantically equivalent, the canonical form doesn't (and shouldn't need to) preserve information about which specific representation was part of the input IR.

To illustrate this with a simple but realistic example, consider different ways to represent "multiplication by 2" (I haven't checked but this is probably something that LLVM canonicalizes today to a single form):

  1. mul %a, 2
  2. shl %a, 1
  3. add %a, %a

If we choose option 2 as our canonical form, we can certainly convert to both options 1 and 3 from it when needed. What we can't do (and what isn't a requirement for canonical forms) is to automatically know which of these three forms the input IR had without any additional context.

Bringing this back to the vector.transpose example, the same principle applies. We can convert from the vector.shape_cast to any the two flavors of the vector.transpose operations but we have to decide which one.

The key point here is: if preserving exactly the original input representation is important for your use case, then canonicalization is not the right transformation to apply at that stage of your pipeline. That is not the right expectation to have for a canonical form.

@dcaballe
Copy link
Contributor

I found this "visualization" from Cursor quite illustrative:

Data Layout Diagram

Original Vector: vector<4x1x1xf32>

Let's assume we have 4 elements: [a, b, c, d]

Original layout (vector<4x1x1xf32>):
Dimension 0 (size=4): [a, b, c, d]
Dimension 1 (size=1): [ ]
Dimension 2 (size=1): [ ]

Conceptual 3D representation:
[[[a]], [[b]], [[c]], [[d]]]

Linear memory layout: [a, b, c, d]

Transpose with permutation [2, 0, 1]: vector<4x1x1xf32> → vector<1x4x1xf32>

After transpose [2, 0, 1]:
- Old dim 2 (size=1) → New dim 0 (size=1)
- Old dim 0 (size=4) → New dim 1 (size=4)  
- Old dim 1 (size=1) → New dim 2 (size=1)

Result layout (vector<1x4x1xf32>):
Dimension 0 (size=1): [ ]
Dimension 1 (size=4): [a, b, c, d]
Dimension 2 (size=1): [ ]

Conceptual 3D representation:
[[[a], [b], [c], [d]]]

Linear memory layout: [a, b, c, d]  (NO CHANGE!)

Transpose with permutation [1, 0, 2]: vector<4x1x1xf32> → vector<1x4x1xf32>

After transpose [1, 0, 2]:
- Old dim 1 (size=1) → New dim 0 (size=1)
- Old dim 0 (size=4) → New dim 1 (size=4)
- Old dim 2 (size=1) → New dim 2 (size=1)

Result layout (vector<1x4x1xf32>):
Dimension 0 (size=1): [ ]
Dimension 1 (size=4): [a, b, c, d]
Dimension 2 (size=1): [ ]

Conceptual 3D representation:
[[[a], [b], [c], [d]]]

Linear memory layout: [a, b, c, d]  (NO CHANGE!)

Visual Diagram

     Original: vector<4x1x1xf32>
     ┌─────────────────────────┐
     │ [[[a]], [[b]], [[c]], [[d]]] │
     └─────────────────────────┘
              │
              ├── transpose [2,0,1] ──┐
              │                       │
              └── transpose [1,0,2] ──┼─┐
                                      │ │
                                      ▼ ▼
                        Both produce: vector<1x4x1xf32>
                        ┌─────────────────────────┐
                        │ [[[a], [b], [c], [d]]]  │
                        └─────────────────────────┘

     Memory Layout (all cases): [a, b, c, d]

Could you think of an example where the actual permutation patterns ([2, 0, 1] vs [1, 0, 2]) are not redundant and lead to semantic differences in the IR?

@dcaballe
Copy link
Contributor

Could you think of an example where the actual permutation patterns ([2, 0, 1] vs [1, 0, 2]) are not redundant and lead to semantic differences in the IR?

Answering my own question, I can think of one use case: any kind of traversal that needs to track or propagate a property across one of the unit dimensions in the example wouldn't be able to do so with the vector.shape_cast form. The "data layout" information becomes ambiguous.

Conclusion 1 : We can’t canonicalize a transpose operation to a shape cast when multiple unit dimensions are transposed. The data layout or dimension mapping across the operation becomes ambiguous with the vector.shape_cast op.

Great, that’s progress! We have identified something technical specific. I suggest that we continue focusing on the technical aspects of the different IR forms. Could we continue this exercise? Could we come up with similar examples for:

  1. Non-ambiguous cases of transposes (i.e., only one unit dim is transposed)
  2. Broadcast (i.e., reshape-like broadcast adding unit dims)

@MaheshRavishankar
Copy link
Contributor

Great, that’s progress! We have identified something technical specific. I suggest that we continue focusing on the technical aspects of the different IR forms. Could we continue this exercise? Could we come up with similar examples for:

Non-ambiguous cases of transposes (i.e., only one unit dim is transposed)
Broadcast (i.e., reshape-like broadcast adding unit dims)

I am happy cursor was able to give you a better explanation of what I was trying to say all this while. Good to have reached this common state.

I think we did discuss previously, then stating that "certain" transpose/broadcasts are canonically shape_casts, and forcing them to then become shape_casts without control is now creating unnecessary complication in the definition of canonicalization. If some transformation is relying on following dimensions through broadcasts/transposes, now it has to look at a shape_cast, decide if this is "convertible to a transpose/broadcast" and then handle that appropriately. This does not seem like a great setup.

@banach-space
Copy link
Contributor

Thank you for the detailed discussion, @dcaballe - that was very helpful in clarifying the underlying issues.

I suggest that we continue focusing on the technical aspects of the different IR forms. Could we continue this exercise?

Agreed.

That’s one concrete example (*). @MaheshRavishankar, could you help us identify other specific cases so that we can better scope or constrain this change?

All in all, given the nuances discussed, I don’t see a specific blocker preventing this from being merged - or is there?

This does not seem like a great setup.

I think we may have to agree to disagree here. That said, as Vector maintainers, we’re committed to supporting all Vector users. If this change ends up causing issues for you, we’ll work with you to address them.

-Andrzej

(*) A transpose operation to a shape cast when multiple unit dimensions are transposed.

@MaheshRavishankar
Copy link
Contributor

All in all, given the nuances discussed, I don’t see a specific blocker preventing this from being merged - or is there?

We might be having different reads of the blocker. To me this discussion is uncovering more reasons why this change shouldnt be merged (this kind of thing is what I was saying would be an issue from the get go).

This does not seem like a great setup.

I think we may have to agree to disagree here. That said, as Vector maintainers, we’re committed to supporting all Vector users. If this change ends up causing issues for you, we’ll work with you to address them.

I want to re-iterate : this is not about my use case. We can find ways to work around things either way. So I am disagree-ing more with the approach here, rather than "this doesnt fit my use case".

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment