Skip to content

Conversation

@newling
Copy link
Contributor

@newling newling commented Apr 11, 2025

This is a follow-up after #133988

  1. The above added the canonicalizer shape_cast(poison) -> poison. This PR adds broadcast(poison) -> poison and transpose(poison) -> poison

  2. [NFC] Reviewer @Groverkss noted that canonicalizers should always be folders where possible ([mlir] canonicalizer: shape_cast(poison) -> poison  #133988 (comment)). This PR moves 2 canonicalizers to folders.

  3. [NFC] added missing ----- between tests in canonicalize.mlir

@llvmbot
Copy link
Member

llvmbot commented Apr 11, 2025

@llvm/pr-subscribers-mlir

Author: James Newling (newling)

Changes

This is a follow-up after #133988

  1. The above added the canonicalizer shape_cast(poison) -> poison. This PR adds broadcast(poison) -> poison and transpose(poison) -> poison

  2. [NFC] Reviewer @Groverkss noted that canonicalizers should always be folders where possible ([mlir] canonicalizer: shape_cast(poison) -> poison  #133988 (comment)). This PR moves 2 canonicalizers to folders.

  3. [NFC] added missing ----- between tests in canonicalize.mlir


Full diff: https://github.com/llvm/llvm-project/pull/135435.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+65-102)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+40)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5324e38fa7d25..71c0ccf8fb1ca 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2590,6 +2590,8 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
   }
   if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
     return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
+  if (llvm::dyn_cast<ub::PoisonAttr>(adaptor.getSource()))
+    return ub::PoisonAttr::get(getContext());
   return {};
 }
 
@@ -3717,6 +3719,59 @@ OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
     return getVector();
   if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
     return getResult();
+
+  Attribute foldInput = adaptor.getVector();
+  if (!foldInput) {
+    return {};
+  }
+
+  // rewrite : ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
+  if (auto splat = llvm::dyn_cast<SplatElementsAttr>(foldInput))
+    DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>());
+
+  // rewrite ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
+  if (auto dense = llvm::dyn_cast<DenseElementsAttr>(foldInput)) {
+    // TODO: Handle non-unit strides when they become available.
+    if (hasNonUnitStrides())
+      return {};
+
+    Value sourceVector = getVector();
+    auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType());
+    ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
+    SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
+
+    VectorType sliceVecTy = getType();
+    ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
+    int64_t sliceRank = sliceVecTy.getRank();
+
+    // Expand offsets and sizes to match the vector rank.
+    SmallVector<int64_t, 4> offsets(sliceRank, 0);
+    copy(getI64SubArray(getOffsets()), offsets.begin());
+
+    SmallVector<int64_t, 4> sizes(sourceShape);
+    copy(getI64SubArray(getSizes()), sizes.begin());
+
+    // Calculate the slice elements by enumerating all slice positions and
+    // linearizing them. The enumeration order is lexicographic which yields a
+    // sequence of monotonically increasing linearized position indices.
+    auto denseValuesBegin = dense.value_begin<Attribute>();
+    SmallVector<Attribute> sliceValues;
+    sliceValues.reserve(sliceVecTy.getNumElements());
+    SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
+    do {
+      int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
+      assert(linearizedPosition < sourceVecTy.getNumElements() &&
+             "Invalid index");
+      sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
+    } while (
+        succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
+
+    assert(static_cast<int64_t>(sliceValues.size()) ==
+               sliceVecTy.getNumElements() &&
+           "Invalid number of slice elements");
+    return DenseElementsAttr::get(sliceVecTy, sliceValues);
+  }
+
   return {};
 }
 
@@ -3781,98 +3836,6 @@ class StridedSliceConstantMaskFolder final
   }
 };
 
-// Pattern to rewrite a ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
-class StridedSliceSplatConstantFolder final
-    : public OpRewritePattern<ExtractStridedSliceOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
-                                PatternRewriter &rewriter) const override {
-    // Return if 'ExtractStridedSliceOp' operand is not defined by a splat
-    // ConstantOp.
-    Value sourceVector = extractStridedSliceOp.getVector();
-    Attribute vectorCst;
-    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
-      return failure();
-
-    auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
-    if (!splat)
-      return failure();
-
-    auto newAttr = SplatElementsAttr::get(extractStridedSliceOp.getType(),
-                                          splat.getSplatValue<Attribute>());
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
-                                                   newAttr);
-    return success();
-  }
-};
-
-// Pattern to rewrite a ExtractStridedSliceOp(non-splat ConstantOp) ->
-// ConstantOp.
-class StridedSliceNonSplatConstantFolder final
-    : public OpRewritePattern<ExtractStridedSliceOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
-                                PatternRewriter &rewriter) const override {
-    // Return if 'ExtractStridedSliceOp' operand is not defined by a non-splat
-    // ConstantOp.
-    Value sourceVector = extractStridedSliceOp.getVector();
-    Attribute vectorCst;
-    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
-      return failure();
-
-    // The splat case is handled by `StridedSliceSplatConstantFolder`.
-    auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
-    if (!dense || dense.isSplat())
-      return failure();
-
-    // TODO: Handle non-unit strides when they become available.
-    if (extractStridedSliceOp.hasNonUnitStrides())
-      return failure();
-
-    auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType());
-    ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
-    SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
-
-    VectorType sliceVecTy = extractStridedSliceOp.getType();
-    ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
-    int64_t sliceRank = sliceVecTy.getRank();
-
-    // Expand offsets and sizes to match the vector rank.
-    SmallVector<int64_t, 4> offsets(sliceRank, 0);
-    copy(getI64SubArray(extractStridedSliceOp.getOffsets()), offsets.begin());
-
-    SmallVector<int64_t, 4> sizes(sourceShape);
-    copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin());
-
-    // Calculate the slice elements by enumerating all slice positions and
-    // linearizing them. The enumeration order is lexicographic which yields a
-    // sequence of monotonically increasing linearized position indices.
-    auto denseValuesBegin = dense.value_begin<Attribute>();
-    SmallVector<Attribute> sliceValues;
-    sliceValues.reserve(sliceVecTy.getNumElements());
-    SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
-    do {
-      int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
-      assert(linearizedPosition < sourceVecTy.getNumElements() &&
-             "Invalid index");
-      sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
-    } while (
-        succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
-
-    assert(static_cast<int64_t>(sliceValues.size()) ==
-               sliceVecTy.getNumElements() &&
-           "Invalid number of slice elements");
-    auto newAttr = DenseElementsAttr::get(sliceVecTy, sliceValues);
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
-                                                   newAttr);
-    return success();
-  }
-};
-
 // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
 // BroadcastOp(ExtractStrideSliceOp).
 class StridedSliceBroadcast final
@@ -4016,8 +3979,7 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
   // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
   // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
-  results.add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
-              StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
+  results.add<StridedSliceConstantMaskFolder, StridedSliceBroadcast,
               StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
       context);
 }
@@ -5654,10 +5616,8 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
   // shape_cast(constant) -> constant
   if (auto splatAttr =
-          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
-    return DenseElementsAttr::get(resultType,
-                                  splatAttr.getSplatValue<Attribute>());
-  }
+          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
+    return splatAttr.reshape(getType());
 
   // shape_cast(poison) -> poison
   if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
@@ -6001,10 +5961,13 @@ void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
 
 OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
   // Eliminate splat constant transpose ops.
-  if (auto attr =
-          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
-    if (attr.isSplat())
-      return attr.reshape(getResultVectorType());
+  if (auto splat =
+          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
+    return splat.reshape(getResultVectorType());
+
+  // Eliminate poison transpose ops.
+  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
+    return ub::PoisonAttr::get(getContext());
 
   // Eliminate identity transpose ops. This happens when the dimensions of the
   // input vector remain in their original order after the transpose operation.
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index a6d82b85777b0..73021a194e5bc 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1121,6 +1121,8 @@ func.func @bitcast_folding(%I1: vector<4x8xf32>, %I2: vector<2xi32>) -> (vector<
   return %0, %2 : vector<4x8xf32>, vector<2xi32>
 }
 
+// -----
+
 // CHECK-LABEL: func @bitcast_f16_to_f32
 //              bit pattern: 0x40004000
 //       CHECK-DAG: %[[CST1:.+]] = arith.constant dense<2.00390625> : vector<4xf32>
@@ -1135,6 +1137,8 @@ func.func @bitcast_f16_to_f32() -> (vector<4xf32>, vector<4xf32>) {
   return %cast0, %cast1: vector<4xf32>, vector<4xf32>
 }
 
+// -----
+
 // CHECK-LABEL: func @bitcast_i8_to_i32
 //              bit pattern: 0xA0A0A0A0
 //       CHECK-DAG: %[[CST1:.+]] = arith.constant dense<-1600085856> : vector<4xi32>
@@ -1176,6 +1180,28 @@ func.func @broadcast_folding2() -> vector<4x16xi32> {
 
 // -----
 
+// CHECK-LABEL: broadcast_poison
+//       CHECK:  %[[POISON:.*]] = ub.poison : vector<4x6xi8>
+//       CHECK:  return %[[POISON]] : vector<4x6xi8>
+func.func @broadcast_poison() -> vector<4x6xi8> {
+  %poison = ub.poison : vector<6xi8>
+  %broadcast = vector.broadcast %poison : vector<6xi8> to vector<4x6xi8>
+  return %broadcast : vector<4x6xi8>
+}
+
+// -----
+
+// CHECK-LABEL:  broadcast_splat_constant
+//       CHECK:  %[[CONST:.*]] = arith.constant dense<1> : vector<4x6xi8>
+//       CHECK:  return %[[CONST]] : vector<4x6xi8>
+func.func @broadcast_splat_constant() -> vector<4x6xi8> {
+  %cst = arith.constant dense<1> : vector<6xi8>
+  %broadcast = vector.broadcast %cst : vector<6xi8> to vector<4x6xi8>
+  return %broadcast : vector<4x6xi8>
+}
+
+// -----
+
 // CHECK-LABEL: @fold_consecutive_broadcasts(
 //  CHECK-SAME:                              %[[ARG0:.*]]: i32
 //       CHECK: %[[RESULT:.*]] = vector.broadcast %[[ARG0]] : i32 to vector<4x16xi32>
@@ -1710,6 +1736,7 @@ func.func @vector_multi_reduction_unit_dimensions(%source: vector<5x1x4x1x20xf32
 }
 
 // -----
+
 // CHECK-LABEL:   func.func @vector_multi_reduction_scalable(
 // CHECK-SAME:     %[[VAL_0:.*]]: vector<1x[4]x1xf32>,
 // CHECK-SAME:     %[[VAL_1:.*]]: vector<1x[4]xf32>,
@@ -2251,6 +2278,8 @@ func.func @transpose_splat_constant() -> vector<8x4xf32> {
   return %0 : vector<8x4xf32>
 }
 
+// -----
+
 // CHECK-LABEL:   func @transpose_splat2(
 // CHECK-SAME:                           %[[VAL_0:.*]]: f32) -> vector<3x4xf32> {
 // CHECK:           %[[VAL_1:.*]] = vector.splat %[[VAL_0]] : vector<3x4xf32>
@@ -2264,6 +2293,17 @@ func.func @transpose_splat2(%arg : f32) -> vector<3x4xf32> {
 
 // -----
 
+// CHECK-LABEL: transpose_poison
+//       CHECK:  %[[POISON:.*]] = ub.poison : vector<4x6xi8>
+//       CHECK:  return %[[POISON]] : vector<4x6xi8>
+func.func @transpose_poison() -> vector<4x6xi8> {
+  %poison = ub.poison : vector<6x4xi8>
+  %transpose = vector.transpose %poison, [1, 0] : vector<6x4xi8> to vector<4x6xi8>
+  return %transpose : vector<4x6xi8>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @insert_1d_constant
 //   CHECK-DAG: %[[ACST:.*]] = arith.constant dense<[9, 1, 2]> : vector<3xi32>
 //   CHECK-DAG: %[[BCST:.*]] = arith.constant dense<[0, 9, 2]> : vector<3xi32>

@llvmbot
Copy link
Member

llvmbot commented Apr 11, 2025

@llvm/pr-subscribers-mlir-vector

Author: James Newling (newling)

Changes

This is a follow-up after #133988

  1. The above added the canonicalizer shape_cast(poison) -> poison. This PR adds broadcast(poison) -> poison and transpose(poison) -> poison

  2. [NFC] Reviewer @Groverkss noted that canonicalizers should always be folders where possible ([mlir] canonicalizer: shape_cast(poison) -> poison  #133988 (comment)). This PR moves 2 canonicalizers to folders.

  3. [NFC] added missing ----- between tests in canonicalize.mlir


Full diff: https://github.com/llvm/llvm-project/pull/135435.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+65-102)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+40)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5324e38fa7d25..71c0ccf8fb1ca 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2590,6 +2590,8 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
   }
   if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
     return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
+  if (llvm::dyn_cast<ub::PoisonAttr>(adaptor.getSource()))
+    return ub::PoisonAttr::get(getContext());
   return {};
 }
 
@@ -3717,6 +3719,59 @@ OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
     return getVector();
   if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
     return getResult();
+
+  Attribute foldInput = adaptor.getVector();
+  if (!foldInput) {
+    return {};
+  }
+
+  // rewrite : ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
+  if (auto splat = llvm::dyn_cast<SplatElementsAttr>(foldInput))
+    DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>());
+
+  // rewrite ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
+  if (auto dense = llvm::dyn_cast<DenseElementsAttr>(foldInput)) {
+    // TODO: Handle non-unit strides when they become available.
+    if (hasNonUnitStrides())
+      return {};
+
+    Value sourceVector = getVector();
+    auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType());
+    ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
+    SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
+
+    VectorType sliceVecTy = getType();
+    ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
+    int64_t sliceRank = sliceVecTy.getRank();
+
+    // Expand offsets and sizes to match the vector rank.
+    SmallVector<int64_t, 4> offsets(sliceRank, 0);
+    copy(getI64SubArray(getOffsets()), offsets.begin());
+
+    SmallVector<int64_t, 4> sizes(sourceShape);
+    copy(getI64SubArray(getSizes()), sizes.begin());
+
+    // Calculate the slice elements by enumerating all slice positions and
+    // linearizing them. The enumeration order is lexicographic which yields a
+    // sequence of monotonically increasing linearized position indices.
+    auto denseValuesBegin = dense.value_begin<Attribute>();
+    SmallVector<Attribute> sliceValues;
+    sliceValues.reserve(sliceVecTy.getNumElements());
+    SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
+    do {
+      int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
+      assert(linearizedPosition < sourceVecTy.getNumElements() &&
+             "Invalid index");
+      sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
+    } while (
+        succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
+
+    assert(static_cast<int64_t>(sliceValues.size()) ==
+               sliceVecTy.getNumElements() &&
+           "Invalid number of slice elements");
+    return DenseElementsAttr::get(sliceVecTy, sliceValues);
+  }
+
   return {};
 }
 
@@ -3781,98 +3836,6 @@ class StridedSliceConstantMaskFolder final
   }
 };
 
-// Pattern to rewrite a ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
-class StridedSliceSplatConstantFolder final
-    : public OpRewritePattern<ExtractStridedSliceOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
-                                PatternRewriter &rewriter) const override {
-    // Return if 'ExtractStridedSliceOp' operand is not defined by a splat
-    // ConstantOp.
-    Value sourceVector = extractStridedSliceOp.getVector();
-    Attribute vectorCst;
-    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
-      return failure();
-
-    auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
-    if (!splat)
-      return failure();
-
-    auto newAttr = SplatElementsAttr::get(extractStridedSliceOp.getType(),
-                                          splat.getSplatValue<Attribute>());
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
-                                                   newAttr);
-    return success();
-  }
-};
-
-// Pattern to rewrite a ExtractStridedSliceOp(non-splat ConstantOp) ->
-// ConstantOp.
-class StridedSliceNonSplatConstantFolder final
-    : public OpRewritePattern<ExtractStridedSliceOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
-                                PatternRewriter &rewriter) const override {
-    // Return if 'ExtractStridedSliceOp' operand is not defined by a non-splat
-    // ConstantOp.
-    Value sourceVector = extractStridedSliceOp.getVector();
-    Attribute vectorCst;
-    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
-      return failure();
-
-    // The splat case is handled by `StridedSliceSplatConstantFolder`.
-    auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
-    if (!dense || dense.isSplat())
-      return failure();
-
-    // TODO: Handle non-unit strides when they become available.
-    if (extractStridedSliceOp.hasNonUnitStrides())
-      return failure();
-
-    auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType());
-    ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
-    SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
-
-    VectorType sliceVecTy = extractStridedSliceOp.getType();
-    ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
-    int64_t sliceRank = sliceVecTy.getRank();
-
-    // Expand offsets and sizes to match the vector rank.
-    SmallVector<int64_t, 4> offsets(sliceRank, 0);
-    copy(getI64SubArray(extractStridedSliceOp.getOffsets()), offsets.begin());
-
-    SmallVector<int64_t, 4> sizes(sourceShape);
-    copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin());
-
-    // Calculate the slice elements by enumerating all slice positions and
-    // linearizing them. The enumeration order is lexicographic which yields a
-    // sequence of monotonically increasing linearized position indices.
-    auto denseValuesBegin = dense.value_begin<Attribute>();
-    SmallVector<Attribute> sliceValues;
-    sliceValues.reserve(sliceVecTy.getNumElements());
-    SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
-    do {
-      int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
-      assert(linearizedPosition < sourceVecTy.getNumElements() &&
-             "Invalid index");
-      sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
-    } while (
-        succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
-
-    assert(static_cast<int64_t>(sliceValues.size()) ==
-               sliceVecTy.getNumElements() &&
-           "Invalid number of slice elements");
-    auto newAttr = DenseElementsAttr::get(sliceVecTy, sliceValues);
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
-                                                   newAttr);
-    return success();
-  }
-};
-
 // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
 // BroadcastOp(ExtractStrideSliceOp).
 class StridedSliceBroadcast final
@@ -4016,8 +3979,7 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
   // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
   // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
-  results.add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
-              StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
+  results.add<StridedSliceConstantMaskFolder, StridedSliceBroadcast,
               StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
       context);
 }
@@ -5654,10 +5616,8 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
   // shape_cast(constant) -> constant
   if (auto splatAttr =
-          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
-    return DenseElementsAttr::get(resultType,
-                                  splatAttr.getSplatValue<Attribute>());
-  }
+          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
+    return splatAttr.reshape(getType());
 
   // shape_cast(poison) -> poison
   if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
@@ -6001,10 +5961,13 @@ void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
 
 OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
   // Eliminate splat constant transpose ops.
-  if (auto attr =
-          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
-    if (attr.isSplat())
-      return attr.reshape(getResultVectorType());
+  if (auto splat =
+          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
+    return splat.reshape(getResultVectorType());
+
+  // Eliminate poison transpose ops.
+  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
+    return ub::PoisonAttr::get(getContext());
 
   // Eliminate identity transpose ops. This happens when the dimensions of the
   // input vector remain in their original order after the transpose operation.
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index a6d82b85777b0..73021a194e5bc 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1121,6 +1121,8 @@ func.func @bitcast_folding(%I1: vector<4x8xf32>, %I2: vector<2xi32>) -> (vector<
   return %0, %2 : vector<4x8xf32>, vector<2xi32>
 }
 
+// -----
+
 // CHECK-LABEL: func @bitcast_f16_to_f32
 //              bit pattern: 0x40004000
 //       CHECK-DAG: %[[CST1:.+]] = arith.constant dense<2.00390625> : vector<4xf32>
@@ -1135,6 +1137,8 @@ func.func @bitcast_f16_to_f32() -> (vector<4xf32>, vector<4xf32>) {
   return %cast0, %cast1: vector<4xf32>, vector<4xf32>
 }
 
+// -----
+
 // CHECK-LABEL: func @bitcast_i8_to_i32
 //              bit pattern: 0xA0A0A0A0
 //       CHECK-DAG: %[[CST1:.+]] = arith.constant dense<-1600085856> : vector<4xi32>
@@ -1176,6 +1180,28 @@ func.func @broadcast_folding2() -> vector<4x16xi32> {
 
 // -----
 
+// CHECK-LABEL: broadcast_poison
+//       CHECK:  %[[POISON:.*]] = ub.poison : vector<4x6xi8>
+//       CHECK:  return %[[POISON]] : vector<4x6xi8>
+func.func @broadcast_poison() -> vector<4x6xi8> {
+  %poison = ub.poison : vector<6xi8>
+  %broadcast = vector.broadcast %poison : vector<6xi8> to vector<4x6xi8>
+  return %broadcast : vector<4x6xi8>
+}
+
+// -----
+
+// CHECK-LABEL:  broadcast_splat_constant
+//       CHECK:  %[[CONST:.*]] = arith.constant dense<1> : vector<4x6xi8>
+//       CHECK:  return %[[CONST]] : vector<4x6xi8>
+func.func @broadcast_splat_constant() -> vector<4x6xi8> {
+  %cst = arith.constant dense<1> : vector<6xi8>
+  %broadcast = vector.broadcast %cst : vector<6xi8> to vector<4x6xi8>
+  return %broadcast : vector<4x6xi8>
+}
+
+// -----
+
 // CHECK-LABEL: @fold_consecutive_broadcasts(
 //  CHECK-SAME:                              %[[ARG0:.*]]: i32
 //       CHECK: %[[RESULT:.*]] = vector.broadcast %[[ARG0]] : i32 to vector<4x16xi32>
@@ -1710,6 +1736,7 @@ func.func @vector_multi_reduction_unit_dimensions(%source: vector<5x1x4x1x20xf32
 }
 
 // -----
+
 // CHECK-LABEL:   func.func @vector_multi_reduction_scalable(
 // CHECK-SAME:     %[[VAL_0:.*]]: vector<1x[4]x1xf32>,
 // CHECK-SAME:     %[[VAL_1:.*]]: vector<1x[4]xf32>,
@@ -2251,6 +2278,8 @@ func.func @transpose_splat_constant() -> vector<8x4xf32> {
   return %0 : vector<8x4xf32>
 }
 
+// -----
+
 // CHECK-LABEL:   func @transpose_splat2(
 // CHECK-SAME:                           %[[VAL_0:.*]]: f32) -> vector<3x4xf32> {
 // CHECK:           %[[VAL_1:.*]] = vector.splat %[[VAL_0]] : vector<3x4xf32>
@@ -2264,6 +2293,17 @@ func.func @transpose_splat2(%arg : f32) -> vector<3x4xf32> {
 
 // -----
 
+// CHECK-LABEL: transpose_poison
+//       CHECK:  %[[POISON:.*]] = ub.poison : vector<4x6xi8>
+//       CHECK:  return %[[POISON]] : vector<4x6xi8>
+func.func @transpose_poison() -> vector<4x6xi8> {
+  %poison = ub.poison : vector<6x4xi8>
+  %transpose = vector.transpose %poison, [1, 0] : vector<6x4xi8> to vector<4x6xi8>
+  return %transpose : vector<4x6xi8>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @insert_1d_constant
 //   CHECK-DAG: %[[ACST:.*]] = arith.constant dense<[9, 1, 2]> : vector<3xi32>
 //   CHECK-DAG: %[[BCST:.*]] = arith.constant dense<[0, 9, 2]> : vector<3xi32>

@banach-space
Copy link
Contributor

Thanks for working on this !

[NFC] Reviewer @Groverkss noted that canonicalizers should always be folders where possible ([mlir] canonicalizer: shape_cast(poison) -> poison #133988 (comment)). This PR moves 2 canonicalizers to folders.

Could this be a separate PR? I'd rather have every commit implement exactly one thing.

@newling
Copy link
Contributor Author

newling commented Apr 14, 2025

> Thanks for working on this !
>
> [NFC] Reviewer @Groverkss noted that canonicalizers should always be folders where possible ([mlir] canonicalizer: shape_cast(poison) -> poison #133988 (comment)). This PR moves 2 canonicalizers to folders.
>
> Could this be a separate PR? I'd rather have every commit implement exactly one thing.

Sure, I'm happy to split this up. I'll close this and post the separate changes soon.

@newling newling closed this Apr 14, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants