@@ -1517,23 +1517,25 @@ LogicalResult arith::TruncIOp::verify() {
15171517// / Perform safe const propagation for truncf, i.e., only propagate if FP value
15181518// / can be represented without precision loss.
15191519OpFoldResult arith::TruncFOp::fold (FoldAdaptor adaptor) {
1520+ auto resElemType = cast<FloatType>(getElementTypeOrSelf (getType ()));
15201521 if (auto extOp = getOperand ().getDefiningOp <arith::ExtFOp>()) {
15211522 Value src = extOp.getIn ();
1522- Type srcType = getElementTypeOrSelf (src.getType ());
1523- Type dstType = getElementTypeOrSelf (getType ());
1524- // truncf(extf(a)) -> truncf(a)
1525- if (llvm::cast<FloatType>(srcType).getWidth () >
1526- llvm::cast<FloatType>(dstType).getWidth ()) {
1527- setOperand (src);
1528- return getResult ();
1529- }
1523+ auto srcType = cast<FloatType>(getElementTypeOrSelf (src.getType ()));
1524+ auto intermediateType = cast<FloatType>(getElementTypeOrSelf (extOp.getType ()));
1525+ // Check if the srcType is representable in the intermediateType
1526+ if (llvm::APFloatBase::isRepresentableBy (srcType.getFloatSemantics (), intermediateType.getFloatSemantics ())) {
1527+ // truncf(extf(a)) -> truncf(a)
1528+ if (srcType.getWidth () > resElemType.getWidth ()) {
1529+ setOperand (src);
1530+ return getResult ();
1531+ }
15301532
1531- // truncf(extf(a)) -> a
1532- if (srcType == dstType)
1533- return src;
1533+ // truncf(extf(a)) -> a
1534+ if (srcType == resElemType)
1535+ return src;
1536+ }
15341537 }
15351538
1536- auto resElemType = cast<FloatType>(getElementTypeOrSelf (getType ()));
15371539 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics ();
15381540 return constFoldCastOp<FloatAttr, FloatAttr>(
15391541 adaptor.getOperands (), getType (),
0 commit comments