@@ -3464,24 +3464,44 @@ SDValue AArch64TargetLowering::LowerFP_TO_INT_SAT(SDValue Op,
3464
3464
uint64_t DstWidth = DstVT.getScalarSizeInBits();
3465
3465
assert(SatWidth <= DstWidth && "Saturation width cannot exceed result width");
3466
3466
3467
- // TODO: Saturate to SatWidth explicitly.
3468
- if (SatWidth != DstWidth)
3467
+ // In the absence of FP16 support, promote f16 to f32 and saturate the result.
3468
+ if (SrcVT == MVT::f16 && !Subtarget->hasFullFP16()) {
3469
+ SrcVal = DAG.getNode(ISD::FP_EXTEND, SDLoc(Op), MVT::f32, SrcVal);
3470
+ SrcVT = MVT::f32;
3471
+ } else if (SrcVT != MVT::f64 && SrcVT != MVT::f32 && SrcVT != MVT::f16)
3469
3472
return SDValue();
3470
3473
3471
- // In the absence of FP16 support, promote f16 to f32, like LowerFP_TO_INT().
3472
- if (SrcVT == MVT::f16 && !Subtarget->hasFullFP16())
3473
- return DAG.getNode(Op.getOpcode(), SDLoc(Op), DstVT,
3474
- DAG.getNode(ISD::FP_EXTEND, SDLoc(Op), MVT::f32, SrcVal),
3475
- Op.getOperand(1));
3476
-
3474
+ SDLoc DL(Op);
3477
3475
// Cases that we can emit directly.
3478
3476
if ((SrcVT == MVT::f64 || SrcVT == MVT::f32 ||
3479
3477
(SrcVT == MVT::f16 && Subtarget->hasFullFP16())) &&
3480
- (DstVT == MVT::i64 || DstVT == MVT::i32))
3481
- return Op;
3478
+ DstVT == SatVT && (DstVT == MVT::i64 || DstVT == MVT::i32))
3479
+ return DAG.getNode(Op.getOpcode(), DL, DstVT, SrcVal,
3480
+ DAG.getValueType(DstVT));
3481
+
3482
+ // Otherwise we emit a cvt that saturates to a higher BW, and saturate the
3483
+ // result. This is only valid if the legal cvt is larger than the saturate
3484
+ // width.
3485
+ if (DstWidth < SatWidth)
3486
+ return SDValue();
3482
3487
3483
- // For all other cases, fall back on the expanded form.
3484
- return SDValue();
3488
+ SDValue NativeCvt =
3489
+ DAG.getNode(Op.getOpcode(), DL, DstVT, SrcVal, DAG.getValueType(DstVT));
3490
+ SDValue Sat;
3491
+ if (Op.getOpcode() == ISD::FP_TO_SINT_SAT) {
3492
+ SDValue MinC = DAG.getConstant(
3493
+ APInt::getSignedMaxValue(SatWidth).sextOrSelf(DstWidth), DL, DstVT);
3494
+ SDValue Min = DAG.getNode(ISD::SMIN, DL, DstVT, NativeCvt, MinC);
3495
+ SDValue MaxC = DAG.getConstant(
3496
+ APInt::getSignedMinValue(SatWidth).sextOrSelf(DstWidth), DL, DstVT);
3497
+ Sat = DAG.getNode(ISD::SMAX, DL, DstVT, Min, MaxC);
3498
+ } else {
3499
+ SDValue MinC = DAG.getConstant(
3500
+ APInt::getAllOnesValue(SatWidth).zextOrSelf(DstWidth), DL, DstVT);
3501
+ Sat = DAG.getNode(ISD::UMIN, DL, DstVT, NativeCvt, MinC);
3502
+ }
3503
+
3504
+ return DAG.getNode(ISD::TRUNCATE, DL, DstVT, Sat);
3485
3505
}
3486
3506
3487
3507
SDValue AArch64TargetLowering::LowerVectorINT_TO_FP(SDValue Op,
0 commit comments