Skip to content

Commit d47886c

Browse files
committed
use resultType
Signed-off-by: James Newling <[email protected]>
1 parent eb5b9d7 commit d47886c

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5617,18 +5617,20 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
56175617
if (getSource().getType() == getType())
56185618
return getSource();
56195619

5620+
VectorType resultType = getType();
5621+
56205622
// Canceling shape casts.
56215623
if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
56225624

56235625
// Only allows valid transitive folding (expand/collapse dimensions).
56245626
VectorType srcType = otherOp.getSource().getType();
5625-
if (getType() == srcType)
5627+
if (resultType == srcType)
56265628
return otherOp.getSource();
5627-
if (srcType.getRank() < getType().getRank()) {
5628-
if (!isValidShapeCast(srcType.getShape(), getType().getShape()))
5629+
if (srcType.getRank() < resultType.getRank()) {
5630+
if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
56295631
return {};
5630-
} else if (srcType.getRank() > getType().getRank()) {
5631-
if (!isValidShapeCast(getType().getShape(), srcType.getShape()))
5632+
} else if (srcType.getRank() > resultType.getRank()) {
5633+
if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
56325634
return {};
56335635
} else {
56345636
return {};
@@ -5639,14 +5641,14 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
56395641

56405642
// Cancelling broadcast and shape cast ops.
56415643
if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
5642-
if (bcastOp.getSourceType() == getType())
5644+
if (bcastOp.getSourceType() == resultType)
56435645
return bcastOp.getSource();
56445646
}
56455647

56465648
// shape_cast(constant) -> constant
56475649
if (auto splatAttr =
56485650
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
5649-
return DenseElementsAttr::get(getType(),
5651+
return DenseElementsAttr::get(resultType,
56505652
splatAttr.getSplatValue<Attribute>());
56515653
}
56525654

0 commit comments

Comments
 (0)