Skip to content

Commit cd85f5d

Browse files
authored
[mlir] canonicalizer: shape_cast(poison) -> poison (llvm#133988)
Based on the ShapeCastConstantFolder, this pattern replaces %0 = ub.poison : vector<2x3xf32> %1 = vector.shape_cast %0 vector<2x3xf32> to vector<6xf32> with %1 = ub.poison : vector<6xf32> --------- Signed-off-by: James Newling <[email protected]>
1 parent a922525 commit cd85f5d

File tree

2 files changed

+39
-32
lines changed

2 files changed

+39
-32
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 25 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
#include "llvm/ADT/SmallVector.h"
4343
#include "llvm/ADT/StringSet.h"
4444
#include "llvm/ADT/TypeSwitch.h"
45+
#include "llvm/Support/Casting.h"
4546

4647
#include <cassert>
4748
#include <cstdint>
@@ -5611,18 +5612,20 @@ LogicalResult ShapeCastOp::verify() {
56115612
}
56125613

56135614
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
5615+
56145616
// No-op shape cast.
5615-
if (getSource().getType() == getResult().getType())
5617+
if (getSource().getType() == getType())
56165618
return getSource();
56175619

5620+
VectorType resultType = getType();
5621+
56185622
// Canceling shape casts.
56195623
if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
5620-
if (getResult().getType() == otherOp.getSource().getType())
5621-
return otherOp.getSource();
56225624

5623-
// Only allows valid transitive folding.
5624-
VectorType srcType = llvm::cast<VectorType>(otherOp.getSource().getType());
5625-
VectorType resultType = llvm::cast<VectorType>(getResult().getType());
5625+
// Only allows valid transitive folding (expand/collapse dimensions).
5626+
VectorType srcType = otherOp.getSource().getType();
5627+
if (resultType == srcType)
5628+
return otherOp.getSource();
56265629
if (srcType.getRank() < resultType.getRank()) {
56275630
if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
56285631
return {};
@@ -5632,43 +5635,32 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
56325635
} else {
56335636
return {};
56345637
}
5635-
56365638
setOperand(otherOp.getSource());
56375639
return getResult();
56385640
}
56395641

56405642
// Cancelling broadcast and shape cast ops.
56415643
if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
5642-
if (bcastOp.getSourceType() == getType())
5644+
if (bcastOp.getSourceType() == resultType)
56435645
return bcastOp.getSource();
56445646
}
56455647

5648+
// shape_cast(constant) -> constant
5649+
if (auto splatAttr =
5650+
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
5651+
return DenseElementsAttr::get(resultType,
5652+
splatAttr.getSplatValue<Attribute>());
5653+
}
5654+
5655+
// shape_cast(poison) -> poison
5656+
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
5657+
return ub::PoisonAttr::get(getContext());
5658+
}
5659+
56465660
return {};
56475661
}
56485662

56495663
namespace {
5650-
// Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
5651-
class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
5652-
public:
5653-
using OpRewritePattern::OpRewritePattern;
5654-
5655-
LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
5656-
PatternRewriter &rewriter) const override {
5657-
auto constantOp =
5658-
shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
5659-
if (!constantOp)
5660-
return failure();
5661-
// Only handle splat for now.
5662-
auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue());
5663-
if (!dense)
5664-
return failure();
5665-
auto newAttr =
5666-
DenseElementsAttr::get(llvm::cast<VectorType>(shapeCastOp.getType()),
5667-
dense.getSplatValue<Attribute>());
5668-
rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr);
5669-
return success();
5670-
}
5671-
};
56725664

56735665
/// Helper function that computes a new vector type based on the input vector
56745666
/// type by removing the trailing one dims:
@@ -5828,8 +5820,9 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
58285820

58295821
void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
58305822
MLIRContext *context) {
5831-
results.add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
5832-
ShapeCastBroadcastFolder>(context);
5823+
results
5824+
.add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
5825+
context);
58335826
}
58345827

58355828
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,6 +1167,20 @@ func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
11671167

11681168
// -----
11691169

1170+
// CHECK-LABEL: shape_cast_poison
1171+
// CHECK-DAG: %[[CST1:.*]] = ub.poison : vector<3x4x2xi32>
1172+
// CHECK-DAG: %[[CST0:.*]] = ub.poison : vector<20x2xf32>
1173+
// CHECK: return %[[CST0]], %[[CST1]] : vector<20x2xf32>, vector<3x4x2xi32>
1174+
func.func @shape_cast_poison() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
1175+
%poison = ub.poison : vector<5x4x2xf32>
1176+
%poison_1 = ub.poison : vector<12x2xi32>
1177+
%0 = vector.shape_cast %poison : vector<5x4x2xf32> to vector<20x2xf32>
1178+
%1 = vector.shape_cast %poison_1 : vector<12x2xi32> to vector<3x4x2xi32>
1179+
return %0, %1 : vector<20x2xf32>, vector<3x4x2xi32>
1180+
}
1181+
1182+
// -----
1183+
11701184
// CHECK-LABEL: extract_strided_constant
11711185
// CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1> : vector<2x13x3xi32>
11721186
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<12x2xf32>

0 commit comments

Comments
 (0)