@@ -828,7 +828,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
828828 setTargetDAGCombine ({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
829829 ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
830830 ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND,
831- ISD::TRUNCATE, ISD::LOAD, ISD::BITCAST});
831+ ISD::TRUNCATE, ISD::LOAD, ISD::STORE, ISD:: BITCAST});
832832
833833 // setcc for f16x2 and bf16x2 needs special handling to prevent
834834 // legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -2992,10 +2992,10 @@ SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
29922992 if (Op.getValueType () == MVT::i1)
29932993 return LowerLOADi1 (Op, DAG);
29942994
2995- // v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to handle
2996- // unaligned loads and have to handle it here.
2995+ // v2f16/v2bf16/v2i16/v4i8/v2f32 are legal, so we can't rely on legalizer to
2996+ // handle unaligned loads and have to handle it here.
29972997 EVT VT = Op.getValueType ();
2998- if (Isv2x16VT (VT) || VT == MVT::v4i8) {
2998+ if (Isv2x16VT (VT) || VT == MVT::v4i8 || VT == MVT::v2f32 ) {
29992999 LoadSDNode *Load = cast<LoadSDNode>(Op);
30003000 EVT MemVT = Load->getMemoryVT ();
30013001 if (!allowsMemoryAccessForAlignment (*DAG.getContext (), DAG.getDataLayout (),
@@ -3039,15 +3039,15 @@ SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
30393039 if (VT == MVT::i1)
30403040 return LowerSTOREi1 (Op, DAG);
30413041
3042- // v2f16 is legal, so we can't rely on legalizer to handle unaligned
3043- // stores and have to handle it here.
3044- if ((Isv2x16VT (VT) || VT == MVT::v4i8) &&
3042+ // v2f16/v2bf16/v2i16/v4i8/v2f32 are legal, so we can't rely on legalizer to
3043+ // handle unaligned stores and have to handle it here.
3044+ if ((Isv2x16VT (VT) || VT == MVT::v4i8 || VT == MVT::v2f32 ) &&
30453045 !allowsMemoryAccessForAlignment (*DAG.getContext (), DAG.getDataLayout (),
30463046 VT, *Store->getMemOperand ()))
30473047 return expandUnalignedStore (Store, DAG);
30483048
3049- // v2f16, v2bf16 and v2i16 don't need special handling.
3050- if (Isv2x16VT (VT) || VT == MVT::v4i8)
3049+ // v2f16/ v2bf16/ v2i16/v4i8/v2f32 don't need special handling.
3050+ if (Isv2x16VT (VT) || VT == MVT::v4i8 || VT == MVT::v2f32 )
30513051 return SDValue ();
30523052
30533053 if (VT.isVector ())
@@ -3056,8 +3056,8 @@ SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
30563056 return SDValue ();
30573057}
30583058
3059- SDValue
3060- NVPTXTargetLowering::LowerSTOREVector (SDValue Op, SelectionDAG &DAG) const {
3059+ static SDValue convertVectorStore (SDValue Op, SelectionDAG &DAG,
3060+ const SmallVectorImpl<SDValue> &Elements) {
30613061 SDNode *N = Op.getNode ();
30623062 SDValue Val = N->getOperand (1 );
30633063 SDLoc DL (N);
@@ -3124,6 +3124,8 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
31243124 SDValue SubVector = DAG.getBuildVector (EltVT, DL, SubVectorElts);
31253125 Ops.push_back (SubVector);
31263126 }
3127+ } else if (!Elements.empty ()) {
3128+ Ops.insert (Ops.end (), Elements.begin (), Elements.end ());
31273129 } else {
31283130 for (unsigned i = 0 ; i < NumElts; ++i) {
31293131 SDValue ExtVal = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
@@ -3141,10 +3143,19 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
31413143 DAG.getMemIntrinsicNode (Opcode, DL, DAG.getVTList (MVT::Other), Ops,
31423144 MemSD->getMemoryVT (), MemSD->getMemOperand ());
31433145
3144- // return DCI.CombineTo(N, NewSt, true);
31453146 return NewSt;
31463147}
31473148
3149+ // Default variant where we don't pass in elements.
3150+ static SDValue convertVectorStore (SDValue Op, SelectionDAG &DAG) {
3151+ return convertVectorStore (Op, DAG, SmallVector<SDValue>{});
3152+ }
3153+
3154+ SDValue NVPTXTargetLowering::LowerSTOREVector (SDValue Op,
3155+ SelectionDAG &DAG) const {
3156+ return convertVectorStore (Op, DAG);
3157+ }
3158+
31483159// st i1 v, addr
31493160// =>
31503161// v1 = zxt v to i16
@@ -5289,6 +5300,9 @@ static SDValue PerformStoreCombineHelper(SDNode *N,
52895300 // -->
52905301 // StoreRetvalV2 {a, b}
52915302 // likewise for V2 -> V4 case
5303+ //
5304+ // We also handle target-independent stores, which require us to first
5305+ // convert to StoreV2.
52925306
52935307 std::optional<NVPTXISD::NodeType> NewOpcode;
52945308 switch (N->getOpcode ()) {
@@ -5314,8 +5328,8 @@ static SDValue PerformStoreCombineHelper(SDNode *N,
53145328 SDValue CurrentOp = N->getOperand (I);
53155329 if (CurrentOp->getOpcode () == ISD::BUILD_VECTOR) {
53165330 assert (CurrentOp.getValueType () == MVT::v2f32);
5317- NewOps.push_back (CurrentOp.getNode ()-> getOperand (0 ));
5318- NewOps.push_back (CurrentOp.getNode ()-> getOperand (1 ));
5331+ NewOps.push_back (CurrentOp.getOperand (0 ));
5332+ NewOps.push_back (CurrentOp.getOperand (1 ));
53195333 } else {
53205334 NewOps.clear ();
53215335 break ;
@@ -6086,6 +6100,18 @@ static SDValue PerformBITCASTCombine(SDNode *N,
60866100 return SDValue ();
60876101}
60886102
6103+ static SDValue PerformStoreCombine (SDNode *N,
6104+ TargetLowering::DAGCombinerInfo &DCI) {
6105+ // check if the store'd value can be scalarized
6106+ SDValue StoredVal = N->getOperand (1 );
6107+ if (StoredVal.getValueType () == MVT::v2f32 &&
6108+ StoredVal.getOpcode () == ISD::BUILD_VECTOR) {
6109+ SmallVector<SDValue> Elements (StoredVal->op_values ());
6110+ return convertVectorStore (SDValue (N, 0 ), DCI.DAG , Elements);
6111+ }
6112+ return SDValue ();
6113+ }
6114+
60896115SDValue NVPTXTargetLowering::PerformDAGCombine (SDNode *N,
60906116 DAGCombinerInfo &DCI) const {
60916117 CodeGenOptLevel OptLevel = getTargetMachine ().getOptLevel ();
@@ -6115,6 +6141,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
61156141 case NVPTXISD::LoadParam:
61166142 case NVPTXISD::LoadParamV2:
61176143 return PerformLoadCombine (N, DCI);
6144+ case ISD::STORE:
6145+ return PerformStoreCombine (N, DCI);
61186146 case NVPTXISD::StoreParam:
61196147 case NVPTXISD::StoreParamV2:
61206148 case NVPTXISD::StoreParamV4:
0 commit comments