@@ -833,7 +833,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
833833 setTargetDAGCombine ({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
834834 ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
835835 ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND,
836- ISD::TRUNCATE, ISD::LOAD, ISD::BITCAST});
836+ ISD::TRUNCATE, ISD::LOAD, ISD::STORE, ISD:: BITCAST});
837837
838838 // setcc for f16x2 and bf16x2 needs special handling to prevent
839839 // legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -3091,10 +3091,10 @@ SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
30913091 if (Op.getValueType () == MVT::i1)
30923092 return LowerLOADi1 (Op, DAG);
30933093
3094- // v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to handle
3095- // unaligned loads and have to handle it here.
3094+ // v2f16/v2bf16/v2i16/v4i8/v2f32 are legal, so we can't rely on legalizer to
3095+ // handle unaligned loads and have to handle it here.
30963096 EVT VT = Op.getValueType ();
3097- if (Isv2x16VT (VT) || VT == MVT::v4i8) {
3097+ if (Isv2x16VT (VT) || VT == MVT::v4i8 || VT == MVT::v2f32 ) {
30983098 LoadSDNode *Load = cast<LoadSDNode>(Op);
30993099 EVT MemVT = Load->getMemoryVT ();
31003100 if (!allowsMemoryAccessForAlignment (*DAG.getContext (), DAG.getDataLayout (),
@@ -3138,15 +3138,15 @@ SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
31383138 if (VT == MVT::i1)
31393139 return LowerSTOREi1 (Op, DAG);
31403140
3141- // v2f16 is legal, so we can't rely on legalizer to handle unaligned
3142- // stores and have to handle it here.
3143- if ((Isv2x16VT (VT) || VT == MVT::v4i8) &&
3141+ // v2f16/v2bf16/v2i16/v4i8/v2f32 are legal, so we can't rely on legalizer to
3142+ // handle unaligned stores and have to handle it here.
3143+ if ((Isv2x16VT (VT) || VT == MVT::v4i8 || VT == MVT::v2f32 ) &&
31443144 !allowsMemoryAccessForAlignment (*DAG.getContext (), DAG.getDataLayout (),
31453145 VT, *Store->getMemOperand ()))
31463146 return expandUnalignedStore (Store, DAG);
31473147
3148- // v2f16, v2bf16 and v2i16 don't need special handling.
3149- if (Isv2x16VT (VT) || VT == MVT::v4i8)
3148+ // v2f16/ v2bf16/ v2i16/v4i8/v2f32 don't need special handling.
3149+ if (Isv2x16VT (VT) || VT == MVT::v4i8 || VT == MVT::v2f32 )
31503150 return SDValue ();
31513151
31523152 if (VT.isVector ())
@@ -3155,8 +3155,8 @@ SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
31553155 return SDValue ();
31563156}
31573157
3158- SDValue
3159- NVPTXTargetLowering::LowerSTOREVector (SDValue Op, SelectionDAG &DAG) const {
3158+ static SDValue convertVectorStore (SDValue Op, SelectionDAG &DAG,
3159+ const SmallVectorImpl<SDValue> &Elements) {
31603160 SDNode *N = Op.getNode ();
31613161 SDValue Val = N->getOperand (1 );
31623162 SDLoc DL (N);
@@ -3223,6 +3223,8 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
32233223 SDValue SubVector = DAG.getBuildVector (EltVT, DL, SubVectorElts);
32243224 Ops.push_back (SubVector);
32253225 }
3226+ } else if (!Elements.empty ()) {
3227+ Ops.insert (Ops.end (), Elements.begin (), Elements.end ());
32263228 } else {
32273229 for (unsigned i = 0 ; i < NumElts; ++i) {
32283230 SDValue ExtVal = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
@@ -3240,10 +3242,19 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
32403242 DAG.getMemIntrinsicNode (Opcode, DL, DAG.getVTList (MVT::Other), Ops,
32413243 MemSD->getMemoryVT (), MemSD->getMemOperand ());
32423244
3243- // return DCI.CombineTo(N, NewSt, true);
32443245 return NewSt;
32453246}
32463247
3248+ // Default variant where we don't pass in elements.
3249+ static SDValue convertVectorStore (SDValue Op, SelectionDAG &DAG) {
3250+ return convertVectorStore (Op, DAG, SmallVector<SDValue>{});
3251+ }
3252+
3253+ SDValue NVPTXTargetLowering::LowerSTOREVector (SDValue Op,
3254+ SelectionDAG &DAG) const {
3255+ return convertVectorStore (Op, DAG);
3256+ }
3257+
32473258// st i1 v, addr
32483259// =>
32493260// v1 = zxt v to i16
@@ -5402,6 +5413,9 @@ static SDValue PerformStoreCombineHelper(SDNode *N,
54025413 // -->
54035414 // StoreRetvalV2 {a, b}
54045415 // likewise for V2 -> V4 case
5416+ //
5417+ // We also handle target-independent stores, which require us to first
5418+ // convert to StoreV2.
54055419
54065420 std::optional<NVPTXISD::NodeType> NewOpcode;
54075421 switch (N->getOpcode ()) {
@@ -5427,8 +5441,8 @@ static SDValue PerformStoreCombineHelper(SDNode *N,
54275441 SDValue CurrentOp = N->getOperand (I);
54285442 if (CurrentOp->getOpcode () == ISD::BUILD_VECTOR) {
54295443 assert (CurrentOp.getValueType () == MVT::v2f32);
5430- NewOps.push_back (CurrentOp.getNode ()-> getOperand (0 ));
5431- NewOps.push_back (CurrentOp.getNode ()-> getOperand (1 ));
5444+ NewOps.push_back (CurrentOp.getOperand (0 ));
5445+ NewOps.push_back (CurrentOp.getOperand (1 ));
54325446 } else {
54335447 NewOps.clear ();
54345448 break ;
@@ -6199,6 +6213,18 @@ static SDValue PerformBITCASTCombine(SDNode *N,
61996213 return SDValue ();
62006214}
62016215
6216+ static SDValue PerformStoreCombine (SDNode *N,
6217+ TargetLowering::DAGCombinerInfo &DCI) {
6218+ // check if the store'd value can be scalarized
6219+ SDValue StoredVal = N->getOperand (1 );
6220+ if (StoredVal.getValueType () == MVT::v2f32 &&
6221+ StoredVal.getOpcode () == ISD::BUILD_VECTOR) {
6222+ SmallVector<SDValue> Elements (StoredVal->op_values ());
6223+ return convertVectorStore (SDValue (N, 0 ), DCI.DAG , Elements);
6224+ }
6225+ return SDValue ();
6226+ }
6227+
62026228SDValue NVPTXTargetLowering::PerformDAGCombine (SDNode *N,
62036229 DAGCombinerInfo &DCI) const {
62046230 CodeGenOptLevel OptLevel = getTargetMachine ().getOptLevel ();
@@ -6228,6 +6254,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
62286254 case NVPTXISD::LoadParam:
62296255 case NVPTXISD::LoadParamV2:
62306256 return PerformLoadCombine (N, DCI);
6257+ case ISD::STORE:
6258+ return PerformStoreCombine (N, DCI);
62316259 case NVPTXISD::StoreParam:
62326260 case NVPTXISD::StoreParamV2:
62336261 case NVPTXISD::StoreParamV4:
0 commit comments