Skip to content

Commit eb5b9d7

Browse files
committed
simplify folding of const/poison
Signed-off-by: James Newling <[email protected]>
1 parent 99ed635 commit eb5b9d7

File tree

1 file changed

+18
-20
lines changed

1 file changed

+18
-20
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
#include "llvm/ADT/SmallVector.h"
4343
#include "llvm/ADT/StringSet.h"
4444
#include "llvm/ADT/TypeSwitch.h"
45+
#include "llvm/Support/Casting.h"
4546

4647
#include <cassert>
4748
#include <cstdint>
@@ -5611,28 +5612,27 @@ LogicalResult ShapeCastOp::verify() {
56115612
}
56125613

56135614
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
5615+
56145616
// No-op shape cast.
5615-
if (getSource().getType() == getResult().getType())
5617+
if (getSource().getType() == getType())
56165618
return getSource();
56175619

56185620
// Canceling shape casts.
56195621
if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
5620-
if (getResult().getType() == otherOp.getSource().getType())
5621-
return otherOp.getSource();
56225622

5623-
// Only allows valid transitive folding.
5624-
VectorType srcType = llvm::cast<VectorType>(otherOp.getSource().getType());
5625-
VectorType resultType = llvm::cast<VectorType>(getResult().getType());
5626-
if (srcType.getRank() < resultType.getRank()) {
5627-
if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
5623+
// Only allows valid transitive folding (expand/collapse dimensions).
5624+
VectorType srcType = otherOp.getSource().getType();
5625+
if (getType() == srcType)
5626+
return otherOp.getSource();
5627+
if (srcType.getRank() < getType().getRank()) {
5628+
if (!isValidShapeCast(srcType.getShape(), getType().getShape()))
56285629
return {};
5629-
} else if (srcType.getRank() > resultType.getRank()) {
5630-
if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
5630+
} else if (srcType.getRank() > getType().getRank()) {
5631+
if (!isValidShapeCast(getType().getShape(), srcType.getShape()))
56315632
return {};
56325633
} else {
56335634
return {};
56345635
}
5635-
56365636
setOperand(otherOp.getSource());
56375637
return getResult();
56385638
}
@@ -5643,17 +5643,15 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
56435643
return bcastOp.getSource();
56445644
}
56455645

5646-
// Replace shape_cast(arith.constant) with arith.constant. Currently only
5647-
// handles splat constants.
5648-
if (auto constantOp = getSource().getDefiningOp<arith::ConstantOp>()) {
5649-
if (auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue())) {
5650-
return DenseElementsAttr::get(cast<VectorType>(getType()),
5651-
dense.getSplatValue<Attribute>());
5652-
}
5646+
// shape_cast(constant) -> constant
5647+
if (auto splatAttr =
5648+
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
5649+
return DenseElementsAttr::get(getType(),
5650+
splatAttr.getSplatValue<Attribute>());
56535651
}
56545652

5655-
// Replace shape_cast(poison) with poison.
5656-
if (getSource().getDefiningOp<ub::PoisonOp>()) {
5653+
// shape_cast(poison) -> poison
5654+
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
56575655
return ub::PoisonAttr::get(getContext());
56585656
}
56595657

0 commit comments

Comments
 (0)