@@ -702,56 +702,57 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
702702 // intrinsics.
703703 setOperationAction (ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);
704704
705- // FP extload/truncstore is not legal in PTX. We need to expand all these.
706- for (auto FloatVTs :
707- {MVT::fp_valuetypes (), MVT::fp_fixedlen_vector_valuetypes ()}) {
708- for (MVT ValVT : FloatVTs) {
709- for (MVT MemVT : FloatVTs) {
710- setLoadExtAction (ISD::EXTLOAD, ValVT, MemVT, Expand);
711- setTruncStoreAction (ValVT, MemVT, Expand);
712- }
713- }
714- }
715-
716- // To improve CodeGen we'll legalize any-extend loads to zext loads. This is
717- // how they'll be lowered in ISel anyway, and by doing this a little earlier
718- // we allow for more DAG combine opportunities.
719- for (auto IntVTs :
720- {MVT::integer_valuetypes (), MVT::integer_fixedlen_vector_valuetypes ()})
721- for (MVT ValVT : IntVTs)
722- for (MVT MemVT : IntVTs)
723- if (isTypeLegal (ValVT))
724- setLoadExtAction (ISD::EXTLOAD, ValVT, MemVT, Custom);
705+ // Turn FP extload into load/fpextend
706+ setLoadExtAction (ISD::EXTLOAD, MVT::f32 , MVT::f16 , Expand);
707+ setLoadExtAction (ISD::EXTLOAD, MVT::f64 , MVT::f16 , Expand);
708+ setLoadExtAction (ISD::EXTLOAD, MVT::f32 , MVT::bf16 , Expand);
709+ setLoadExtAction (ISD::EXTLOAD, MVT::f64 , MVT::bf16 , Expand);
710+ setLoadExtAction (ISD::EXTLOAD, MVT::f64 , MVT::f32 , Expand);
711+ setLoadExtAction (ISD::EXTLOAD, MVT::v2f32, MVT::v2f16, Expand);
712+ setLoadExtAction (ISD::EXTLOAD, MVT::v2f64, MVT::v2f16, Expand);
713+ setLoadExtAction (ISD::EXTLOAD, MVT::v2f32, MVT::v2bf16, Expand);
714+ setLoadExtAction (ISD::EXTLOAD, MVT::v2f64, MVT::v2bf16, Expand);
715+ setLoadExtAction (ISD::EXTLOAD, MVT::v2f64, MVT::v2f32, Expand);
716+ setLoadExtAction (ISD::EXTLOAD, MVT::v4f32, MVT::v4f16, Expand);
717+ setLoadExtAction (ISD::EXTLOAD, MVT::v4f64, MVT::v4f16, Expand);
718+ setLoadExtAction (ISD::EXTLOAD, MVT::v4f32, MVT::v4bf16, Expand);
719+ setLoadExtAction (ISD::EXTLOAD, MVT::v4f64, MVT::v4bf16, Expand);
720+ setLoadExtAction (ISD::EXTLOAD, MVT::v4f64, MVT::v4f32, Expand);
721+ setLoadExtAction (ISD::EXTLOAD, MVT::v8f32, MVT::v8f16, Expand);
722+ setLoadExtAction (ISD::EXTLOAD, MVT::v8f64, MVT::v8f16, Expand);
723+ setLoadExtAction (ISD::EXTLOAD, MVT::v8f32, MVT::v8bf16, Expand);
724+ setLoadExtAction (ISD::EXTLOAD, MVT::v8f64, MVT::v8bf16, Expand);
725+ // Turn FP truncstore into trunc + store.
726+ // FIXME: vector types should also be expanded
727+ setTruncStoreAction (MVT::f32 , MVT::f16 , Expand);
728+ setTruncStoreAction (MVT::f64 , MVT::f16 , Expand);
729+ setTruncStoreAction (MVT::f32 , MVT::bf16 , Expand);
730+ setTruncStoreAction (MVT::f64 , MVT::bf16 , Expand);
731+ setTruncStoreAction (MVT::f64 , MVT::f32 , Expand);
732+ setTruncStoreAction (MVT::v2f32, MVT::v2f16, Expand);
733+ setTruncStoreAction (MVT::v2f32, MVT::v2bf16, Expand);
725734
726735 // PTX does not support load / store predicate registers
727- setOperationAction ({ISD::LOAD, ISD::STORE}, MVT::i1, Custom);
736+ setOperationAction (ISD::LOAD, MVT::i1, Custom);
737+ setOperationAction (ISD::STORE, MVT::i1, Custom);
738+
728739 for (MVT VT : MVT::integer_valuetypes ()) {
729- setLoadExtAction ({ISD::SEXTLOAD, ISD::ZEXTLOAD, ISD::EXTLOAD}, VT, MVT::i1,
730- Promote);
740+ setLoadExtAction (ISD::SEXTLOAD, VT, MVT::i1, Promote);
741+ setLoadExtAction (ISD::ZEXTLOAD, VT, MVT::i1, Promote);
742+ setLoadExtAction (ISD::EXTLOAD, VT, MVT::i1, Promote);
731743 setTruncStoreAction (VT, MVT::i1, Expand);
732744 }
733745
734- // Register custom handling for illegal type loads/stores. We'll try to custom
735- // lower almost all illegal types and logic in the lowering will discard cases
736- // we can't handle.
737- setOperationAction ({ISD::LOAD, ISD::STORE}, {MVT::i128 , MVT::f128 }, Custom);
738- for (MVT VT : MVT::fixedlen_vector_valuetypes ())
739- if (!isTypeLegal (VT) && VT.getStoreSizeInBits () <= 256 )
740- setOperationAction ({ISD::STORE, ISD::LOAD}, VT, Custom);
741-
742- // Custom legalization for LDU intrinsics.
743- // TODO: The logic to lower these is not very robust and we should rewrite it.
744- // Perhaps LDU should not be represented as an intrinsic at all.
745- setOperationAction (ISD::INTRINSIC_W_CHAIN, MVT::i8 , Custom);
746- for (MVT VT : MVT::fixedlen_vector_valuetypes ())
747- if (IsPTXVectorType (VT))
748- setOperationAction (ISD::INTRINSIC_W_CHAIN, VT, Custom);
749-
750746 setCondCodeAction ({ISD::SETNE, ISD::SETEQ, ISD::SETUGE, ISD::SETULE,
751747 ISD::SETUGT, ISD::SETULT, ISD::SETGT, ISD::SETLT,
752748 ISD::SETGE, ISD::SETLE},
753749 MVT::i1, Expand);
754750
751+ // Expand extending loads and truncating stores between v2i16 and v2i8.
752+ setLoadExtAction ({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, MVT::v2i16,
753+ MVT::v2i8, Expand);
754+ setTruncStoreAction (MVT::v2i16, MVT::v2i8, Expand);
755+
755756 // This is legal in NVPTX
756757 setOperationAction (ISD::ConstantFP, MVT::f64 , Legal);
757758 setOperationAction (ISD::ConstantFP, MVT::f32 , Legal);
@@ -766,12 +767,24 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
766767 // DEBUGTRAP can be lowered to PTX brkpt
767768 setOperationAction (ISD::DEBUGTRAP, MVT::Other, Legal);
768769
770+ // Register custom handling for vector loads/stores
771+ for (MVT VT : MVT::fixedlen_vector_valuetypes ())
772+ if (IsPTXVectorType (VT))
773+ setOperationAction ({ISD::LOAD, ISD::STORE, ISD::INTRINSIC_W_CHAIN}, VT,
774+ Custom);
775+
776+ setOperationAction ({ISD::LOAD, ISD::STORE, ISD::INTRINSIC_W_CHAIN},
777+ {MVT::i128 , MVT::f128 }, Custom);
778+
769779 // Support varargs.
770780 setOperationAction (ISD::VASTART, MVT::Other, Custom);
771781 setOperationAction (ISD::VAARG, MVT::Other, Custom);
772782 setOperationAction (ISD::VACOPY, MVT::Other, Expand);
773783 setOperationAction (ISD::VAEND, MVT::Other, Expand);
774784
785+ // Custom handling for i8 intrinsics
786+ setOperationAction (ISD::INTRINSIC_W_CHAIN, MVT::i8 , Custom);
787+
775788 setOperationAction ({ISD::ABS, ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX},
776789 {MVT::i16 , MVT::i32 , MVT::i64 }, Legal);
777790
@@ -3079,14 +3092,39 @@ static void replaceLoadVector(SDNode *N, SelectionDAG &DAG,
30793092 SmallVectorImpl<SDValue> &Results,
30803093 const NVPTXSubtarget &STI);
30813094
// Custom lowering entry point for load nodes.
//  - A load whose result type is i1 is delegated to LowerLOADi1.
//  - For a result type accepted by NVPTX::isPackedVectorTy, a load whose memory
//    operand fails allowsMemoryAccessForAlignment is rewritten through
//    expandUnalignedLoad and handed back as a single MergeValues node.
//  - In every other case an empty SDValue is returned; presumably that tells the
//    caller to keep its default handling of the node -- confirm against the caller.
3095+ SDValue NVPTXTargetLowering::LowerLOAD (SDValue Op, SelectionDAG &DAG) const {
3096+ if (Op.getValueType () == MVT::i1)
3097+ return LowerLOADi1 (Op, DAG);
3098+
3099+ EVT VT = Op.getValueType (); // NOTE(review): same query as the i1 check above, now cached for the packed-vector test
3100+
3101+ if (NVPTX::isPackedVectorTy (VT)) {
3102+ // v2f32/v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to
3103+ // handle unaligned loads and have to handle it here.
3104+ LoadSDNode *Load = cast<LoadSDNode>(Op);
3105+ EVT MemVT = Load->getMemoryVT (); // the alignment check below uses the in-memory type, not VT
3106+ if (!allowsMemoryAccessForAlignment (*DAG.getContext (), DAG.getDataLayout (),
3107+ MemVT, *Load->getMemOperand ())) {
3108+ SDValue Ops[2 ];
3109+ std::tie (Ops[0 ], Ops[1 ]) = expandUnalignedLoad (Load, DAG); // presumably (loaded value, chain) -- TODO confirm
3110+ return DAG.getMergeValues (Ops, SDLoc (Op)); // both results are bundled into one returned node
3111+ }
3112+ }
3113+
3114+ return SDValue (); // no custom lowering was applied on this path
3115+ }
3116+
30823117// v = ld i1* addr
30833118// =>
30843119// v1 = ld i8* addr (-> i16)
30853120// v = trunc i16 to i1
3086- static SDValue lowerLOADi1 (LoadSDNode *LD, SelectionDAG &DAG) {
3087- SDLoc dl (LD);
3121+ SDValue NVPTXTargetLowering::LowerLOADi1 (SDValue Op, SelectionDAG &DAG) const {
3122+ SDNode *Node = Op.getNode ();
3123+ LoadSDNode *LD = cast<LoadSDNode>(Node);
3124+ SDLoc dl (Node);
30883125 assert (LD->getExtensionType () == ISD::NON_EXTLOAD);
3089- assert (LD->getValueType (0 ) == MVT::i1 && " Custom lowering for i1 load only" );
3126+ assert (Node->getValueType (0 ) == MVT::i1 &&
3127+ " Custom lowering for i1 load only" );
30903128 SDValue newLD = DAG.getExtLoad (ISD::ZEXTLOAD, dl, MVT::i16 , LD->getChain (),
30913129 LD->getBasePtr (), LD->getPointerInfo (),
30923130 MVT::i8 , LD->getAlign (),
@@ -3095,27 +3133,8 @@ static SDValue lowerLOADi1(LoadSDNode *LD, SelectionDAG &DAG) {
30953133 // The legalizer (the caller) is expecting two values from the legalized
30963134 // load, so we build a MergeValues node for it. See ExpandUnalignedLoad()
30973135 // in LegalizeDAG.cpp which also uses MergeValues.
3098- return DAG.getMergeValues ({result, LD->getChain ()}, dl);
3099- }
3100-
3101- SDValue NVPTXTargetLowering::LowerLOAD (SDValue Op, SelectionDAG &DAG) const {
3102- LoadSDNode *LD = cast<LoadSDNode>(Op);
3103-
3104- if (Op.getValueType () == MVT::i1)
3105- return lowerLOADi1 (LD, DAG);
3106-
3107- // To improve CodeGen we'll legalize any-extend loads to zext loads. This is
3108- // how they'll be lowered in ISel anyway, and by doing this a little earlier
3109- // we allow for more DAG combine opportunities.
3110- if (LD->getExtensionType () == ISD::EXTLOAD) {
3111- assert (LD->getValueType (0 ).isInteger () && LD->getMemoryVT ().isInteger () &&
3112- " Unexpected fpext-load" );
3113- return DAG.getExtLoad (ISD::ZEXTLOAD, SDLoc (Op), Op.getValueType (),
3114- LD->getChain (), LD->getBasePtr (), LD->getMemoryVT (),
3115- LD->getMemOperand ());
3116- }
3117-
3118- llvm_unreachable (" Unexpected custom lowering for load" );
3136+ SDValue Ops[] = { result, LD->getChain () };
3137+ return DAG.getMergeValues (Ops, dl);
31193138}
31203139
31213140SDValue NVPTXTargetLowering::LowerSTORE (SDValue Op, SelectionDAG &DAG) const {
@@ -3125,6 +3144,17 @@ SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
31253144 if (VT == MVT::i1)
31263145 return LowerSTOREi1 (Op, DAG);
31273146
3147+ // v2f32/v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to
3148+ // handle unaligned stores and have to handle it here.
3149+ if (NVPTX::isPackedVectorTy (VT) &&
3150+ !allowsMemoryAccessForAlignment (*DAG.getContext (), DAG.getDataLayout (),
3151+ VT, *Store->getMemOperand ()))
3152+ return expandUnalignedStore (Store, DAG);
3153+
3154+ // 32-bit packed vectors (v2f16/v2bf16/v2i16/v4i8) don't need special handling.
3155+ if (NVPTX::isPackedVectorTy (VT) && VT.is32BitVector ())
3156+ return SDValue ();
3157+
31283158 // Lower store of any other vector type, including v2f32 as we want to break
31293159 // it apart since this is not a widely-supported type.
31303160 return LowerSTOREVector (Op, DAG);
@@ -3980,8 +4010,14 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
39804010 case Intrinsic::nvvm_ldu_global_i:
39814011 case Intrinsic::nvvm_ldu_global_f:
39824012 case Intrinsic::nvvm_ldu_global_p: {
4013+ auto &DL = I.getDataLayout ();
39834014 Info.opc = ISD::INTRINSIC_W_CHAIN;
3984- Info.memVT = getValueType (I.getDataLayout (), I.getType ());
4015+ if (Intrinsic == Intrinsic::nvvm_ldu_global_i)
4016+ Info.memVT = getValueType (DL, I.getType ());
4017+ else if (Intrinsic == Intrinsic::nvvm_ldu_global_p)
4018+ Info.memVT = getPointerTy (DL);
4019+ else
4020+ Info.memVT = getValueType (DL, I.getType ());
39854021 Info.ptrVal = I.getArgOperand (0 );
39864022 Info.offset = 0 ;
39874023 Info.flags = MachineMemOperand::MOLoad;
0 commit comments