42
42
#include " llvm/ADT/SmallVector.h"
43
43
#include " llvm/ADT/StringSet.h"
44
44
#include " llvm/ADT/TypeSwitch.h"
45
+ #include " llvm/Support/Casting.h"
45
46
46
47
#include < cassert>
47
48
#include < cstdint>
@@ -5611,18 +5612,20 @@ LogicalResult ShapeCastOp::verify() {
5611
5612
}
5612
5613
5613
5614
OpFoldResult ShapeCastOp::fold (FoldAdaptor adaptor) {
5615
+
5614
5616
// No-op shape cast.
5615
- if (getSource ().getType () == getResult (). getType ())
5617
+ if (getSource ().getType () == getType ())
5616
5618
return getSource ();
5617
5619
5620
+ VectorType resultType = getType ();
5621
+
5618
5622
// Canceling shape casts.
5619
5623
if (auto otherOp = getSource ().getDefiningOp <ShapeCastOp>()) {
5620
- if (getResult ().getType () == otherOp.getSource ().getType ())
5621
- return otherOp.getSource ();
5622
5624
5623
- // Only allows valid transitive folding.
5624
- VectorType srcType = llvm::cast<VectorType>(otherOp.getSource ().getType ());
5625
- VectorType resultType = llvm::cast<VectorType>(getResult ().getType ());
5625
+ // Only allows valid transitive folding (expand/collapse dimensions).
5626
+ VectorType srcType = otherOp.getSource ().getType ();
5627
+ if (resultType == srcType)
5628
+ return otherOp.getSource ();
5626
5629
if (srcType.getRank () < resultType.getRank ()) {
5627
5630
if (!isValidShapeCast (srcType.getShape (), resultType.getShape ()))
5628
5631
return {};
@@ -5632,43 +5635,32 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
5632
5635
} else {
5633
5636
return {};
5634
5637
}
5635
-
5636
5638
setOperand (otherOp.getSource ());
5637
5639
return getResult ();
5638
5640
}
5639
5641
5640
5642
// Cancelling broadcast and shape cast ops.
5641
5643
if (auto bcastOp = getSource ().getDefiningOp <BroadcastOp>()) {
5642
- if (bcastOp.getSourceType () == getType () )
5644
+ if (bcastOp.getSourceType () == resultType )
5643
5645
return bcastOp.getSource ();
5644
5646
}
5645
5647
5648
+ // shape_cast(constant) -> constant
5649
+ if (auto splatAttr =
5650
+ llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource ())) {
5651
+ return DenseElementsAttr::get (resultType,
5652
+ splatAttr.getSplatValue <Attribute>());
5653
+ }
5654
+
5655
+ // shape_cast(poison) -> poison
5656
+ if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource ())) {
5657
+ return ub::PoisonAttr::get (getContext ());
5658
+ }
5659
+
5646
5660
return {};
5647
5661
}
5648
5662
5649
5663
namespace {
5650
- // Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
5651
- class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
5652
- public:
5653
- using OpRewritePattern::OpRewritePattern;
5654
-
5655
- LogicalResult matchAndRewrite (ShapeCastOp shapeCastOp,
5656
- PatternRewriter &rewriter) const override {
5657
- auto constantOp =
5658
- shapeCastOp.getSource ().getDefiningOp <arith::ConstantOp>();
5659
- if (!constantOp)
5660
- return failure ();
5661
- // Only handle splat for now.
5662
- auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue ());
5663
- if (!dense)
5664
- return failure ();
5665
- auto newAttr =
5666
- DenseElementsAttr::get (llvm::cast<VectorType>(shapeCastOp.getType ()),
5667
- dense.getSplatValue <Attribute>());
5668
- rewriter.replaceOpWithNewOp <arith::ConstantOp>(shapeCastOp, newAttr);
5669
- return success ();
5670
- }
5671
- };
5672
5664
5673
5665
// / Helper function that computes a new vector type based on the input vector
5674
5666
// / type by removing the trailing one dims:
@@ -5828,8 +5820,9 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
5828
5820
5829
5821
void ShapeCastOp::getCanonicalizationPatterns (RewritePatternSet &results,
5830
5822
MLIRContext *context) {
5831
- results.add <ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
5832
- ShapeCastBroadcastFolder>(context);
5823
+ results
5824
+ .add <ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
5825
+ context);
5833
5826
}
5834
5827
5835
5828
// ===----------------------------------------------------------------------===//
0 commit comments