@@ -825,11 +825,33 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
825825 Custom);
826826 }
827827
828- // Avoid true 16 instruction
829- if (!Subtarget->hasTrue16BitInsts() || !Subtarget->useRealTrue16Insts()) {
830- // MVT::v2i16 for src type check in foldToSaturated
831- // MVT::v2i8 for dst type check in CustomLowerNode
832- setOperationAction(ISD::TRUNCATE_SSAT_U, {MVT::v2i16, MVT::v2i8}, Custom);
828+ // True 16 instruction is current not supported
829+ // FIXME: Add support for true 16 when supported
830+ if (!(Subtarget->hasTrue16BitInsts() && Subtarget->useRealTrue16Insts())) {
831+ // MVT::vNi16 for src type check in foldToSaturated
832+ // MVT::vNi8 for dst type check in CustomLowerNode
833+ setOperationAction(ISD::TRUNCATE_SSAT_U,
834+ {
835+ MVT::v2i16,
836+ MVT::v4i16,
837+ MVT::v8i16,
838+ MVT::v16i16,
839+ MVT::v32i16,
840+ MVT::v64i16,
841+ MVT::v128i16,
842+ MVT::v256i16,
843+ MVT::v512i16,
844+ MVT::v2i8,
845+ MVT::v4i8,
846+ MVT::v8i8,
847+ MVT::v16i8,
848+ MVT::v32i8,
849+ MVT::v64i8,
850+ MVT::v128i8,
851+ MVT::v256i8,
852+ MVT::v512i8,
853+ },
854+ Custom);
833855 }
834856 }
835857
@@ -1990,8 +2012,10 @@ bool SITargetLowering::isTypeDesirableForOp(unsigned Op, EVT VT) const {
19902012 if (VT == MVT::i1 && Op == ISD::SETCC)
19912013 return false;
19922014
1993- // v2i8 is illegal and only allowed in specific cases
1994- if (VT == MVT::v2i8 && Op == ISD::TRUNCATE_SSAT_U)
2015+ // Special case for vNi8 handling where N is even
2016+ if (Op == ISD::TRUNCATE_SSAT_U && VT.isVector() &&
2017+ VT.getVectorElementType() == MVT::i8 &&
2018+ ((VT.getVectorNumElements() & 1) == 0))
19952019 return true;
19962020
19972021 return TargetLowering::isTypeDesirableForOp(Op, VT);
@@ -6628,10 +6652,39 @@ void SITargetLowering::ReplaceNodeResults(SDNode *N,
66286652 }
66296653 case ISD::TRUNCATE_SSAT_U: {
66306654 SDLoc SL(N);
6631- SDValue Op =
6632- DAG.getNode(AMDGPUISD::SAT_PK_CAST, SL, MVT::i16, N->getOperand(0));
6633- Op = DAG.getNode(ISD::BITCAST, SL, MVT::v2i8, Op);
6634- Results.push_back(Op);
6655+ SDValue Src = N->getOperand(0);
6656+ EVT SrcVT = Src.getValueType();
6657+ EVT DstVT = N->getValueType(0);
6658+
6659+ assert(SrcVT.isVector() && DstVT.isVector());
6660+
6661+ unsigned EleNo = SrcVT.getVectorNumElements();
6662+ assert(EleNo == DstVT.getVectorNumElements());
6663+
6664+ if (EleNo == 2) {
6665+ SDValue Op =
6666+ DAG.getNode(AMDGPUISD::SAT_PK_CAST, SL, MVT::i16, N->getOperand(0));
6667+ Op = DAG.getNode(ISD::BITCAST, SL, N->getValueType(0), Op);
6668+ Results.push_back(Op);
6669+ } else {
6670+ // Must be even number
6671+ assert((EleNo & 1) == 0);
6672+ SmallVector<SDValue> DstPairs;
6673+ EVT SrcEleVT = SrcVT.getVectorElementType();
6674+ EVT DstEleVT = DstVT.getVectorElementType();
6675+ EVT SrcPairVT = EVT::getVectorVT(*DAG.getContext(), SrcEleVT, 2);
6676+ EVT DstPairVT = EVT::getVectorVT(*DAG.getContext(), DstEleVT, 2);
6677+ for (unsigned i = 0; i + 1 < EleNo; i = i + 2) {
6678+ SDValue SrcPair = DAG.getNode(ISD::EXTRACT_SUBVECTOR, SL, SrcPairVT,
6679+ Src, DAG.getConstant(i, SL, MVT::i32));
6680+ SDValue SatPk =
6681+ DAG.getNode(AMDGPUISD::SAT_PK_CAST, SL, MVT::i16, SrcPair);
6682+ SDValue DstPair = DAG.getNode(ISD::BITCAST, SL, DstPairVT, SatPk);
6683+ DstPairs.push_back(DstPair);
6684+ }
6685+ SDValue Op = DAG.getNode(ISD::CONCAT_VECTORS, SL, DstVT, DstPairs);
6686+ Results.push_back(Op);
6687+ }
66356688 break;
66366689 }
66376690 default:
0 commit comments